SagemakerRuntime: invoke_endpoint_async() implementation (#7211)

This commit is contained in:
Bert Blommers 2024-01-13 21:09:57 +00:00 committed by GitHub
parent 455fbd5eaa
commit 3b3f718d41
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 98 additions and 1 deletions

View File

@ -1,6 +1,8 @@
import json
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
from moto.core import BackendDict, BaseBackend from moto.core import BackendDict, BaseBackend
from moto.moto_api._internal import mock_random as random
class SageMakerRuntimeBackend(BaseBackend): class SageMakerRuntimeBackend(BaseBackend):
@ -8,6 +10,7 @@ class SageMakerRuntimeBackend(BaseBackend):
def __init__(self, region_name: str, account_id: str): def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id) super().__init__(region_name, account_id)
self.async_results: Dict[str, Dict[str, str]] = {}
self.results: Dict[str, Dict[bytes, Tuple[str, str, str, str]]] = {} self.results: Dict[str, Dict[bytes, Tuple[str, str, str, str]]] = {}
self.results_queue: List[Tuple[str, str, str, str]] = [] self.results_queue: List[Tuple[str, str, str, str]] = []
@ -62,5 +65,37 @@ class SageMakerRuntimeBackend(BaseBackend):
) )
return self.results[endpoint_name][unique_repr] return self.results[endpoint_name][unique_repr]
def invoke_endpoint_async(self, endpoint_name: str, input_location: str) -> str:
if endpoint_name not in self.async_results:
self.async_results[endpoint_name] = {}
if input_location in self.async_results[endpoint_name]:
return self.async_results[endpoint_name][input_location]
if self.results_queue:
body, _type, variant, attrs = self.results_queue.pop(0)
else:
body = "body"
_type = "content_type"
variant = "invoked_production_variant"
attrs = "custom_attributes"
json_data = {
"Body": body,
"ContentType": _type,
"InvokedProductionVariant": variant,
"CustomAttributes": attrs,
}
output = json.dumps(json_data).encode("utf-8")
output_bucket = f"sagemaker-output-{random.uuid4()}"
output_location = "response.json"
from moto.s3.models import s3_backends
s3_backend = s3_backends[self.account_id]["global"]
s3_backend.create_bucket(output_bucket, region_name=self.region_name)
s3_backend.put_object(output_bucket, output_location, value=output)
self.async_results[endpoint_name][
input_location
] = f"s3://{output_bucket}/{output_location}"
return self.async_results[endpoint_name][input_location]
sagemakerruntime_backends = BackendDict(SageMakerRuntimeBackend, "sagemaker-runtime") sagemakerruntime_backends = BackendDict(SageMakerRuntimeBackend, "sagemaker-runtime")

View File

@ -3,6 +3,7 @@ import json
from moto.core.common_types import TYPE_RESPONSE from moto.core.common_types import TYPE_RESPONSE
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.moto_api._internal import mock_random as random
from .models import SageMakerRuntimeBackend, sagemakerruntime_backends from .models import SageMakerRuntimeBackend, sagemakerruntime_backends
@ -43,3 +44,14 @@ class SageMakerRuntimeResponse(BaseResponse):
if custom_attributes: if custom_attributes:
headers["X-Amzn-SageMaker-Custom-Attributes"] = custom_attributes headers["X-Amzn-SageMaker-Custom-Attributes"] = custom_attributes
return 200, headers, body return 200, headers, body
def invoke_endpoint_async(self) -> TYPE_RESPONSE:
endpoint_name = self.path.split("/")[2]
input_location = self.headers.get("X-Amzn-SageMaker-InputLocation")
inference_id = self.headers.get("X-Amzn-SageMaker-Inference-Id")
location = self.sagemakerruntime_backend.invoke_endpoint_async(
endpoint_name, input_location
)
resp = {"InferenceId": inference_id or str(random.uuid4())}
headers = {"X-Amzn-SageMaker-OutputLocation": location}
return 200, headers, json.dumps(resp)

View File

@ -10,5 +10,6 @@ response = SageMakerRuntimeResponse()
url_paths = { url_paths = {
"{0}/endpoints/(?P<name>[^/]+)/async-invocations$": response.dispatch,
"{0}/endpoints/(?P<name>[^/]+)/invocations$": response.dispatch, "{0}/endpoints/(?P<name>[^/]+)/invocations$": response.dispatch,
} }

View File

@ -1,7 +1,10 @@
import json
import boto3 import boto3
import requests import requests
from moto import mock_sagemakerruntime, settings from moto import mock_s3, mock_sagemakerruntime, settings
from moto.s3.utils import bucket_and_name_from_url
# See our Development Tips on writing tests for hints on how to write good tests: # See our Development Tips on writing tests for hints on how to write good tests:
# http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html # http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html
@ -54,3 +57,49 @@ def test_invoke_endpoint():
EndpointName="asdf", Body="qwer", Accept="sth", TargetModel="tm" EndpointName="asdf", Body="qwer", Accept="sth", TargetModel="tm"
) )
assert body["Body"].read() == b"second body" assert body["Body"].read() == b"second body"
@mock_s3
@mock_sagemakerruntime
def test_invoke_endpoint_async():
client = boto3.client("sagemaker-runtime", region_name="us-east-1")
base_url = (
"localhost:5000" if settings.TEST_SERVER_MODE else "motoapi.amazonaws.com"
)
sagemaker_result = {
"results": [
{
"Body": "first body",
"ContentType": "text/xml",
"InvokedProductionVariant": "prod",
"CustomAttributes": "my_attr",
},
{"Body": "second body"},
]
}
requests.post(
f"http://{base_url}/moto-api/static/sagemaker/endpoint-results",
json=sagemaker_result,
)
# Return the first item from the list
body = client.invoke_endpoint_async(EndpointName="asdf", InputLocation="qwer")
first_output_location = body["OutputLocation"]
# Same input -> same output
body = client.invoke_endpoint_async(EndpointName="asdf", InputLocation="qwer")
assert body["OutputLocation"] == first_output_location
# Different input -> second item
body = client.invoke_endpoint_async(
EndpointName="asdf", InputLocation="asf", InferenceId="sth"
)
second_output_location = body["OutputLocation"]
assert body["InferenceId"] == "sth"
s3 = boto3.client("s3", "us-east-1")
bucket_name, obj = bucket_and_name_from_url(second_output_location)
resp = s3.get_object(Bucket=bucket_name, Key=obj)
resp = json.loads(resp["Body"].read().decode("utf-8"))
assert resp["Body"] == "second body"