From 3b3f718d41b7c182db4fc59c4a09324b5abc7922 Mon Sep 17 00:00:00 2001
From: Bert Blommers
Date: Sat, 13 Jan 2024 21:09:57 +0000
Subject: [PATCH] SagemakerRuntime: invoke_endpoint_async() implementation (#7211)

---
Notes (ignored by `git am`):
  * models.py: adds SageMakerRuntimeBackend.invoke_endpoint_async(). The
    result is cached per (endpoint_name, input_location), so repeating a
    call returns the same location. On a cache miss the next entry is popped
    from results_queue (or placeholder values are used when the queue is
    empty), serialised as JSON, and stored as "response.json" in a freshly
    created "sagemaker-output-<uuid>" S3 bucket; the s3:// URL is returned.
  * responses.py: adds the HTTP handler. It reads the endpoint name from the
    URL path and the input location / inference id from request headers,
    returns the location in the X-Amzn-SageMaker-OutputLocation header, and
    generates an InferenceId when the caller did not send one.
  * urls.py: routes ".../async-invocations" to the dispatcher.
  * tests: covers first call, repeated call, and a second distinct input.

 moto/sagemakerruntime/models.py               | 35 +++++++++++++
 moto/sagemakerruntime/responses.py            | 12 +++++
 moto/sagemakerruntime/urls.py                 |  1 +
 .../test_sagemakerruntime.py                  | 51 ++++++++++++++++++-
 4 files changed, 98 insertions(+), 1 deletion(-)

diff --git a/moto/sagemakerruntime/models.py b/moto/sagemakerruntime/models.py
index 26928bcf8..872ca4464 100644
--- a/moto/sagemakerruntime/models.py
+++ b/moto/sagemakerruntime/models.py
@@ -1,6 +1,8 @@
+import json
 from typing import Dict, List, Tuple
 
 from moto.core import BackendDict, BaseBackend
+from moto.moto_api._internal import mock_random as random
 
 
 class SageMakerRuntimeBackend(BaseBackend):
@@ -8,6 +10,7 @@ class SageMakerRuntimeBackend(BaseBackend):
 
     def __init__(self, region_name: str, account_id: str):
         super().__init__(region_name, account_id)
+        self.async_results: Dict[str, Dict[str, str]] = {}
         self.results: Dict[str, Dict[bytes, Tuple[str, str, str, str]]] = {}
         self.results_queue: List[Tuple[str, str, str, str]] = []
 
@@ -62,5 +65,37 @@ class SageMakerRuntimeBackend(BaseBackend):
             )
         return self.results[endpoint_name][unique_repr]
 
+    def invoke_endpoint_async(self, endpoint_name: str, input_location: str) -> str:
+        if endpoint_name not in self.async_results:
+            self.async_results[endpoint_name] = {}
+        if input_location in self.async_results[endpoint_name]:
+            return self.async_results[endpoint_name][input_location]
+        if self.results_queue:
+            body, _type, variant, attrs = self.results_queue.pop(0)
+        else:
+            body = "body"
+            _type = "content_type"
+            variant = "invoked_production_variant"
+            attrs = "custom_attributes"
+        json_data = {
+            "Body": body,
+            "ContentType": _type,
+            "InvokedProductionVariant": variant,
+            "CustomAttributes": attrs,
+        }
+        output = json.dumps(json_data).encode("utf-8")
+
+        output_bucket = f"sagemaker-output-{random.uuid4()}"
+        output_location = "response.json"
+        from moto.s3.models import s3_backends
+
+        s3_backend = s3_backends[self.account_id]["global"]
+        s3_backend.create_bucket(output_bucket, region_name=self.region_name)
+        s3_backend.put_object(output_bucket, output_location, value=output)
+        self.async_results[endpoint_name][
+            input_location
+        ] = f"s3://{output_bucket}/{output_location}"
+        return self.async_results[endpoint_name][input_location]
+
 
 sagemakerruntime_backends = BackendDict(SageMakerRuntimeBackend, "sagemaker-runtime")
diff --git a/moto/sagemakerruntime/responses.py b/moto/sagemakerruntime/responses.py
index ee1b86608..c30c4b2b6 100644
--- a/moto/sagemakerruntime/responses.py
+++ b/moto/sagemakerruntime/responses.py
@@ -3,6 +3,7 @@ import json
 
 from moto.core.common_types import TYPE_RESPONSE
 from moto.core.responses import BaseResponse
+from moto.moto_api._internal import mock_random as random
 
 from .models import SageMakerRuntimeBackend, sagemakerruntime_backends
 
@@ -43,3 +44,14 @@ class SageMakerRuntimeResponse(BaseResponse):
         if custom_attributes:
             headers["X-Amzn-SageMaker-Custom-Attributes"] = custom_attributes
         return 200, headers, body
+
+    def invoke_endpoint_async(self) -> TYPE_RESPONSE:
+        endpoint_name = self.path.split("/")[2]
+        input_location = self.headers.get("X-Amzn-SageMaker-InputLocation")
+        inference_id = self.headers.get("X-Amzn-SageMaker-Inference-Id")
+        location = self.sagemakerruntime_backend.invoke_endpoint_async(
+            endpoint_name, input_location
+        )
+        resp = {"InferenceId": inference_id or str(random.uuid4())}
+        headers = {"X-Amzn-SageMaker-OutputLocation": location}
+        return 200, headers, json.dumps(resp)
diff --git a/moto/sagemakerruntime/urls.py b/moto/sagemakerruntime/urls.py
index 01c753437..f8b6976ae 100644
--- a/moto/sagemakerruntime/urls.py
+++ b/moto/sagemakerruntime/urls.py
@@ -10,5 +10,6 @@ response = SageMakerRuntimeResponse()
 
 
 url_paths = {
+    "{0}/endpoints/(?P<name>[^/]+)/async-invocations$": response.dispatch,
     "{0}/endpoints/(?P<name>[^/]+)/invocations$": response.dispatch,
 }
diff --git a/tests/test_sagemakerruntime/test_sagemakerruntime.py b/tests/test_sagemakerruntime/test_sagemakerruntime.py
index 60b256d5a..9ccece62a 100644
--- a/tests/test_sagemakerruntime/test_sagemakerruntime.py
+++ b/tests/test_sagemakerruntime/test_sagemakerruntime.py
@@ -1,7 +1,10 @@
+import json
+
 import boto3
 import requests
 
-from moto import mock_sagemakerruntime, settings
+from moto import mock_s3, mock_sagemakerruntime, settings
+from moto.s3.utils import bucket_and_name_from_url
 
 # See our Development Tips on writing tests for hints on how to write good tests:
 # http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html
@@ -54,3 +57,49 @@ def test_invoke_endpoint():
         EndpointName="asdf", Body="qwer", Accept="sth", TargetModel="tm"
     )
     assert body["Body"].read() == b"second body"
+
+
+@mock_s3
+@mock_sagemakerruntime
+def test_invoke_endpoint_async():
+    client = boto3.client("sagemaker-runtime", region_name="us-east-1")
+    base_url = (
+        "localhost:5000" if settings.TEST_SERVER_MODE else "motoapi.amazonaws.com"
+    )
+
+    sagemaker_result = {
+        "results": [
+            {
+                "Body": "first body",
+                "ContentType": "text/xml",
+                "InvokedProductionVariant": "prod",
+                "CustomAttributes": "my_attr",
+            },
+            {"Body": "second body"},
+        ]
+    }
+    requests.post(
+        f"http://{base_url}/moto-api/static/sagemaker/endpoint-results",
+        json=sagemaker_result,
+    )
+
+    # Return the first item from the list
+    body = client.invoke_endpoint_async(EndpointName="asdf", InputLocation="qwer")
+    first_output_location = body["OutputLocation"]
+
+    # Same input -> same output
+    body = client.invoke_endpoint_async(EndpointName="asdf", InputLocation="qwer")
+    assert body["OutputLocation"] == first_output_location
+
+    # Different input -> second item
+    body = client.invoke_endpoint_async(
+        EndpointName="asdf", InputLocation="asf", InferenceId="sth"
+    )
+    second_output_location = body["OutputLocation"]
+    assert body["InferenceId"] == "sth"
+
+    s3 = boto3.client("s3", "us-east-1")
+    bucket_name, obj = bucket_and_name_from_url(second_output_location)
+    resp = s3.get_object(Bucket=bucket_name, Key=obj)
+    resp = json.loads(resp["Body"].read().decode("utf-8"))
+    assert resp["Body"] == "second body"