SagemakerRuntime: invoke_endpoint_async() implementation (#7211)
This commit is contained in:
parent
455fbd5eaa
commit
3b3f718d41
@ -1,6 +1,8 @@
|
||||
import json
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from moto.core import BackendDict, BaseBackend
|
||||
from moto.moto_api._internal import mock_random as random
|
||||
|
||||
|
||||
class SageMakerRuntimeBackend(BaseBackend):
|
||||
@ -8,6 +10,7 @@ class SageMakerRuntimeBackend(BaseBackend):
|
||||
|
||||
def __init__(self, region_name: str, account_id: str):
|
||||
super().__init__(region_name, account_id)
|
||||
self.async_results: Dict[str, Dict[str, str]] = {}
|
||||
self.results: Dict[str, Dict[bytes, Tuple[str, str, str, str]]] = {}
|
||||
self.results_queue: List[Tuple[str, str, str, str]] = []
|
||||
|
||||
@ -62,5 +65,37 @@ class SageMakerRuntimeBackend(BaseBackend):
|
||||
)
|
||||
return self.results[endpoint_name][unique_repr]
|
||||
|
||||
def invoke_endpoint_async(self, endpoint_name: str, input_location: str) -> str:
|
||||
if endpoint_name not in self.async_results:
|
||||
self.async_results[endpoint_name] = {}
|
||||
if input_location in self.async_results[endpoint_name]:
|
||||
return self.async_results[endpoint_name][input_location]
|
||||
if self.results_queue:
|
||||
body, _type, variant, attrs = self.results_queue.pop(0)
|
||||
else:
|
||||
body = "body"
|
||||
_type = "content_type"
|
||||
variant = "invoked_production_variant"
|
||||
attrs = "custom_attributes"
|
||||
json_data = {
|
||||
"Body": body,
|
||||
"ContentType": _type,
|
||||
"InvokedProductionVariant": variant,
|
||||
"CustomAttributes": attrs,
|
||||
}
|
||||
output = json.dumps(json_data).encode("utf-8")
|
||||
|
||||
output_bucket = f"sagemaker-output-{random.uuid4()}"
|
||||
output_location = "response.json"
|
||||
from moto.s3.models import s3_backends
|
||||
|
||||
s3_backend = s3_backends[self.account_id]["global"]
|
||||
s3_backend.create_bucket(output_bucket, region_name=self.region_name)
|
||||
s3_backend.put_object(output_bucket, output_location, value=output)
|
||||
self.async_results[endpoint_name][
|
||||
input_location
|
||||
] = f"s3://{output_bucket}/{output_location}"
|
||||
return self.async_results[endpoint_name][input_location]
|
||||
|
||||
|
||||
sagemakerruntime_backends = BackendDict(SageMakerRuntimeBackend, "sagemaker-runtime")
|
||||
|
@ -3,6 +3,7 @@ import json
|
||||
|
||||
from moto.core.common_types import TYPE_RESPONSE
|
||||
from moto.core.responses import BaseResponse
|
||||
from moto.moto_api._internal import mock_random as random
|
||||
|
||||
from .models import SageMakerRuntimeBackend, sagemakerruntime_backends
|
||||
|
||||
@ -43,3 +44,14 @@ class SageMakerRuntimeResponse(BaseResponse):
|
||||
if custom_attributes:
|
||||
headers["X-Amzn-SageMaker-Custom-Attributes"] = custom_attributes
|
||||
return 200, headers, body
|
||||
|
||||
def invoke_endpoint_async(self) -> TYPE_RESPONSE:
|
||||
endpoint_name = self.path.split("/")[2]
|
||||
input_location = self.headers.get("X-Amzn-SageMaker-InputLocation")
|
||||
inference_id = self.headers.get("X-Amzn-SageMaker-Inference-Id")
|
||||
location = self.sagemakerruntime_backend.invoke_endpoint_async(
|
||||
endpoint_name, input_location
|
||||
)
|
||||
resp = {"InferenceId": inference_id or str(random.uuid4())}
|
||||
headers = {"X-Amzn-SageMaker-OutputLocation": location}
|
||||
return 200, headers, json.dumps(resp)
|
||||
|
@ -10,5 +10,6 @@ response = SageMakerRuntimeResponse()
|
||||
|
||||
|
||||
url_paths = {
|
||||
"{0}/endpoints/(?P<name>[^/]+)/async-invocations$": response.dispatch,
|
||||
"{0}/endpoints/(?P<name>[^/]+)/invocations$": response.dispatch,
|
||||
}
|
||||
|
@ -1,7 +1,10 @@
|
||||
import json
|
||||
|
||||
import boto3
|
||||
import requests
|
||||
|
||||
from moto import mock_sagemakerruntime, settings
|
||||
from moto import mock_s3, mock_sagemakerruntime, settings
|
||||
from moto.s3.utils import bucket_and_name_from_url
|
||||
|
||||
# See our Development Tips on writing tests for hints on how to write good tests:
|
||||
# http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html
|
||||
@ -54,3 +57,49 @@ def test_invoke_endpoint():
|
||||
EndpointName="asdf", Body="qwer", Accept="sth", TargetModel="tm"
|
||||
)
|
||||
assert body["Body"].read() == b"second body"
|
||||
|
||||
|
||||
@mock_s3
|
||||
@mock_sagemakerruntime
|
||||
def test_invoke_endpoint_async():
|
||||
client = boto3.client("sagemaker-runtime", region_name="us-east-1")
|
||||
base_url = (
|
||||
"localhost:5000" if settings.TEST_SERVER_MODE else "motoapi.amazonaws.com"
|
||||
)
|
||||
|
||||
sagemaker_result = {
|
||||
"results": [
|
||||
{
|
||||
"Body": "first body",
|
||||
"ContentType": "text/xml",
|
||||
"InvokedProductionVariant": "prod",
|
||||
"CustomAttributes": "my_attr",
|
||||
},
|
||||
{"Body": "second body"},
|
||||
]
|
||||
}
|
||||
requests.post(
|
||||
f"http://{base_url}/moto-api/static/sagemaker/endpoint-results",
|
||||
json=sagemaker_result,
|
||||
)
|
||||
|
||||
# Return the first item from the list
|
||||
body = client.invoke_endpoint_async(EndpointName="asdf", InputLocation="qwer")
|
||||
first_output_location = body["OutputLocation"]
|
||||
|
||||
# Same input -> same output
|
||||
body = client.invoke_endpoint_async(EndpointName="asdf", InputLocation="qwer")
|
||||
assert body["OutputLocation"] == first_output_location
|
||||
|
||||
# Different input -> second item
|
||||
body = client.invoke_endpoint_async(
|
||||
EndpointName="asdf", InputLocation="asf", InferenceId="sth"
|
||||
)
|
||||
second_output_location = body["OutputLocation"]
|
||||
assert body["InferenceId"] == "sth"
|
||||
|
||||
s3 = boto3.client("s3", "us-east-1")
|
||||
bucket_name, obj = bucket_and_name_from_url(second_output_location)
|
||||
resp = s3.get_object(Bucket=bucket_name, Key=obj)
|
||||
resp = json.loads(resp["Body"].read().decode("utf-8"))
|
||||
assert resp["Body"] == "second body"
|
||||
|
Loading…
Reference in New Issue
Block a user