SagemakerRuntime: invoke_endpoint_async() implementation (#7211)
This commit is contained in:
parent
455fbd5eaa
commit
3b3f718d41
@ -1,6 +1,8 @@
|
|||||||
|
import json
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
from moto.core import BackendDict, BaseBackend
|
from moto.core import BackendDict, BaseBackend
|
||||||
|
from moto.moto_api._internal import mock_random as random
|
||||||
|
|
||||||
|
|
||||||
class SageMakerRuntimeBackend(BaseBackend):
|
class SageMakerRuntimeBackend(BaseBackend):
|
||||||
@ -8,6 +10,7 @@ class SageMakerRuntimeBackend(BaseBackend):
|
|||||||
|
|
||||||
def __init__(self, region_name: str, account_id: str):
|
def __init__(self, region_name: str, account_id: str):
|
||||||
super().__init__(region_name, account_id)
|
super().__init__(region_name, account_id)
|
||||||
|
self.async_results: Dict[str, Dict[str, str]] = {}
|
||||||
self.results: Dict[str, Dict[bytes, Tuple[str, str, str, str]]] = {}
|
self.results: Dict[str, Dict[bytes, Tuple[str, str, str, str]]] = {}
|
||||||
self.results_queue: List[Tuple[str, str, str, str]] = []
|
self.results_queue: List[Tuple[str, str, str, str]] = []
|
||||||
|
|
||||||
@ -62,5 +65,37 @@ class SageMakerRuntimeBackend(BaseBackend):
|
|||||||
)
|
)
|
||||||
return self.results[endpoint_name][unique_repr]
|
return self.results[endpoint_name][unique_repr]
|
||||||
|
|
||||||
|
def invoke_endpoint_async(self, endpoint_name: str, input_location: str) -> str:
    """Return the S3 URL holding the mocked async result for an input.

    Results are remembered per endpoint and input location, so repeating
    a call with the same arguments returns the same output location
    without consuming another queued result.

    For a new input, the next entry of ``self.results_queue`` is used
    (or a set of fixed placeholder values when the queue is empty). The
    result is serialised to JSON and stored under the key
    ``response.json`` in a freshly created ``sagemaker-output-<uuid>``
    bucket.

    :param endpoint_name: name of the endpoint being invoked
    :param input_location: location of the request payload; used only as
        the cache key
    :return: the output location, formatted as ``s3://<bucket>/<key>``
    """
    known_outputs = self.async_results.setdefault(endpoint_name, {})
    if input_location in known_outputs:
        return known_outputs[input_location]

    # Take the oldest configured result, or fall back to placeholders.
    queued_result = (
        self.results_queue.pop(0)
        if self.results_queue
        else (
            "body",
            "content_type",
            "invoked_production_variant",
            "custom_attributes",
        )
    )
    body, content_type, variant, attributes = queued_result
    payload = json.dumps(
        {
            "Body": body,
            "ContentType": content_type,
            "InvokedProductionVariant": variant,
            "CustomAttributes": attributes,
        }
    ).encode("utf-8")

    bucket_name = f"sagemaker-output-{random.uuid4()}"
    object_key = "response.json"

    # Imported at call time, exactly where the result is persisted.
    from moto.s3.models import s3_backends

    s3 = s3_backends[self.account_id]["global"]
    s3.create_bucket(bucket_name, region_name=self.region_name)
    s3.put_object(bucket_name, object_key, value=payload)

    result_url = f"s3://{bucket_name}/{object_key}"
    known_outputs[input_location] = result_url
    return result_url
|
||||||
|
|
||||||
|
|
||||||
sagemakerruntime_backends = BackendDict(SageMakerRuntimeBackend, "sagemaker-runtime")
|
sagemakerruntime_backends = BackendDict(SageMakerRuntimeBackend, "sagemaker-runtime")
|
||||||
|
@ -3,6 +3,7 @@ import json
|
|||||||
|
|
||||||
from moto.core.common_types import TYPE_RESPONSE
|
from moto.core.common_types import TYPE_RESPONSE
|
||||||
from moto.core.responses import BaseResponse
|
from moto.core.responses import BaseResponse
|
||||||
|
from moto.moto_api._internal import mock_random as random
|
||||||
|
|
||||||
from .models import SageMakerRuntimeBackend, sagemakerruntime_backends
|
from .models import SageMakerRuntimeBackend, sagemakerruntime_backends
|
||||||
|
|
||||||
@ -43,3 +44,14 @@ class SageMakerRuntimeResponse(BaseResponse):
|
|||||||
if custom_attributes:
|
if custom_attributes:
|
||||||
headers["X-Amzn-SageMaker-Custom-Attributes"] = custom_attributes
|
headers["X-Amzn-SageMaker-Custom-Attributes"] = custom_attributes
|
||||||
return 200, headers, body
|
return 200, headers, body
|
||||||
|
|
||||||
|
def invoke_endpoint_async(self) -> TYPE_RESPONSE:
    """Handle an async-invocation request for an endpoint.

    The endpoint name is taken from the third ``/``-separated segment of
    the request path; the input location and the optional inference id
    are read from the request headers.

    :return: a ``(status, headers, body)`` tuple. The output location
        reported by the backend is sent in the
        ``X-Amzn-SageMaker-OutputLocation`` header, and the JSON body
        carries the ``InferenceId`` (the caller's value, or a generated
        UUID when none was supplied).
    """
    path_segments = self.path.split("/")
    endpoint_name = path_segments[2]
    input_location = self.headers.get("X-Amzn-SageMaker-InputLocation")
    inference_id = self.headers.get("X-Amzn-SageMaker-Inference-Id")

    output_location = self.sagemakerruntime_backend.invoke_endpoint_async(
        endpoint_name, input_location
    )

    # Only generate an id when the caller did not provide one.
    if not inference_id:
        inference_id = str(random.uuid4())

    response_headers = {"X-Amzn-SageMaker-OutputLocation": output_location}
    response_body = json.dumps({"InferenceId": inference_id})
    return 200, response_headers, response_body
|
||||||
|
@ -10,5 +10,6 @@ response = SageMakerRuntimeResponse()
|
|||||||
|
|
||||||
|
|
||||||
url_paths = {
|
url_paths = {
|
||||||
|
"{0}/endpoints/(?P<name>[^/]+)/async-invocations$": response.dispatch,
|
||||||
"{0}/endpoints/(?P<name>[^/]+)/invocations$": response.dispatch,
|
"{0}/endpoints/(?P<name>[^/]+)/invocations$": response.dispatch,
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,10 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
import boto3
|
import boto3
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from moto import mock_sagemakerruntime, settings
|
from moto import mock_s3, mock_sagemakerruntime, settings
|
||||||
|
from moto.s3.utils import bucket_and_name_from_url
|
||||||
|
|
||||||
# See our Development Tips on writing tests for hints on how to write good tests:
|
# See our Development Tips on writing tests for hints on how to write good tests:
|
||||||
# http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html
|
# http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html
|
||||||
@ -54,3 +57,49 @@ def test_invoke_endpoint():
|
|||||||
EndpointName="asdf", Body="qwer", Accept="sth", TargetModel="tm"
|
EndpointName="asdf", Body="qwer", Accept="sth", TargetModel="tm"
|
||||||
)
|
)
|
||||||
assert body["Body"].read() == b"second body"
|
assert body["Body"].read() == b"second body"
|
||||||
|
|
||||||
|
|
||||||
|
@mock_s3
@mock_sagemakerruntime
def test_invoke_endpoint_async():
    """Async invocations store queued results in S3 and cache by input."""
    runtime = boto3.client("sagemaker-runtime", region_name="us-east-1")
    if settings.TEST_SERVER_MODE:
        api_host = "localhost:5000"
    else:
        api_host = "motoapi.amazonaws.com"

    # Queue two results for the mocked endpoint via the moto API.
    first_result = {
        "Body": "first body",
        "ContentType": "text/xml",
        "InvokedProductionVariant": "prod",
        "CustomAttributes": "my_attr",
    }
    second_result = {"Body": "second body"}
    requests.post(
        f"http://{api_host}/moto-api/static/sagemaker/endpoint-results",
        json={"results": [first_result, second_result]},
    )

    # The first call consumes the first queued result
    response = runtime.invoke_endpoint_async(
        EndpointName="asdf", InputLocation="qwer"
    )
    first_location = response["OutputLocation"]

    # Repeating the same input must yield the same output location
    response = runtime.invoke_endpoint_async(
        EndpointName="asdf", InputLocation="qwer"
    )
    assert response["OutputLocation"] == first_location

    # A new input consumes the second queued result
    response = runtime.invoke_endpoint_async(
        EndpointName="asdf", InputLocation="asf", InferenceId="sth"
    )
    second_location = response["OutputLocation"]
    assert response["InferenceId"] == "sth"

    # The stored object holds the JSON-encoded second result
    s3_client = boto3.client("s3", "us-east-1")
    bucket, key = bucket_and_name_from_url(second_location)
    stored_object = s3_client.get_object(Bucket=bucket, Key=key)
    stored_result = json.loads(stored_object["Body"].read().decode("utf-8"))
    assert stored_result["Body"] == "second body"
|
||||||
|
Loading…
Reference in New Issue
Block a user