SagemakerRuntime: invoke_endpoint_async() implementation (#7211)
This commit is contained in:
		
							parent
							
								
									455fbd5eaa
								
							
						
					
					
						commit
						3b3f718d41
					
				| @ -1,6 +1,8 @@ | ||||
| import json | ||||
| from typing import Dict, List, Tuple | ||||
| 
 | ||||
| from moto.core import BackendDict, BaseBackend | ||||
| from moto.moto_api._internal import mock_random as random | ||||
| 
 | ||||
| 
 | ||||
| class SageMakerRuntimeBackend(BaseBackend): | ||||
| @ -8,6 +10,7 @@ class SageMakerRuntimeBackend(BaseBackend): | ||||
| 
 | ||||
|     def __init__(self, region_name: str, account_id: str): | ||||
|         super().__init__(region_name, account_id) | ||||
|         self.async_results: Dict[str, Dict[str, str]] = {} | ||||
|         self.results: Dict[str, Dict[bytes, Tuple[str, str, str, str]]] = {} | ||||
|         self.results_queue: List[Tuple[str, str, str, str]] = [] | ||||
| 
 | ||||
| @ -62,5 +65,37 @@ class SageMakerRuntimeBackend(BaseBackend): | ||||
|             ) | ||||
|         return self.results[endpoint_name][unique_repr] | ||||
| 
 | ||||
|     def invoke_endpoint_async(self, endpoint_name: str, input_location: str) -> str: | ||||
|         if endpoint_name not in self.async_results: | ||||
|             self.async_results[endpoint_name] = {} | ||||
|         if input_location in self.async_results[endpoint_name]: | ||||
|             return self.async_results[endpoint_name][input_location] | ||||
|         if self.results_queue: | ||||
|             body, _type, variant, attrs = self.results_queue.pop(0) | ||||
|         else: | ||||
|             body = "body" | ||||
|             _type = "content_type" | ||||
|             variant = "invoked_production_variant" | ||||
|             attrs = "custom_attributes" | ||||
|         json_data = { | ||||
|             "Body": body, | ||||
|             "ContentType": _type, | ||||
|             "InvokedProductionVariant": variant, | ||||
|             "CustomAttributes": attrs, | ||||
|         } | ||||
|         output = json.dumps(json_data).encode("utf-8") | ||||
| 
 | ||||
|         output_bucket = f"sagemaker-output-{random.uuid4()}" | ||||
|         output_location = "response.json" | ||||
|         from moto.s3.models import s3_backends | ||||
| 
 | ||||
|         s3_backend = s3_backends[self.account_id]["global"] | ||||
|         s3_backend.create_bucket(output_bucket, region_name=self.region_name) | ||||
|         s3_backend.put_object(output_bucket, output_location, value=output) | ||||
|         self.async_results[endpoint_name][ | ||||
|             input_location | ||||
|         ] = f"s3://{output_bucket}/{output_location}" | ||||
|         return self.async_results[endpoint_name][input_location] | ||||
| 
 | ||||
| 
 | ||||
| sagemakerruntime_backends = BackendDict(SageMakerRuntimeBackend, "sagemaker-runtime") | ||||
|  | ||||
| @ -3,6 +3,7 @@ import json | ||||
| 
 | ||||
| from moto.core.common_types import TYPE_RESPONSE | ||||
| from moto.core.responses import BaseResponse | ||||
| from moto.moto_api._internal import mock_random as random | ||||
| 
 | ||||
| from .models import SageMakerRuntimeBackend, sagemakerruntime_backends | ||||
| 
 | ||||
| @ -43,3 +44,14 @@ class SageMakerRuntimeResponse(BaseResponse): | ||||
|         if custom_attributes: | ||||
|             headers["X-Amzn-SageMaker-Custom-Attributes"] = custom_attributes | ||||
|         return 200, headers, body | ||||
| 
 | ||||
|     def invoke_endpoint_async(self) -> TYPE_RESPONSE: | ||||
|         endpoint_name = self.path.split("/")[2] | ||||
|         input_location = self.headers.get("X-Amzn-SageMaker-InputLocation") | ||||
|         inference_id = self.headers.get("X-Amzn-SageMaker-Inference-Id") | ||||
|         location = self.sagemakerruntime_backend.invoke_endpoint_async( | ||||
|             endpoint_name, input_location | ||||
|         ) | ||||
|         resp = {"InferenceId": inference_id or str(random.uuid4())} | ||||
|         headers = {"X-Amzn-SageMaker-OutputLocation": location} | ||||
|         return 200, headers, json.dumps(resp) | ||||
|  | ||||
| @ -10,5 +10,6 @@ response = SageMakerRuntimeResponse() | ||||
| 
 | ||||
| 
 | ||||
| url_paths = { | ||||
|     "{0}/endpoints/(?P<name>[^/]+)/async-invocations$": response.dispatch, | ||||
|     "{0}/endpoints/(?P<name>[^/]+)/invocations$": response.dispatch, | ||||
| } | ||||
|  | ||||
| @ -1,7 +1,10 @@ | ||||
| import json | ||||
| 
 | ||||
| import boto3 | ||||
| import requests | ||||
| 
 | ||||
| from moto import mock_sagemakerruntime, settings | ||||
| from moto import mock_s3, mock_sagemakerruntime, settings | ||||
| from moto.s3.utils import bucket_and_name_from_url | ||||
| 
 | ||||
| # See our Development Tips on writing tests for hints on how to write good tests: | ||||
| # http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html | ||||
| @ -54,3 +57,49 @@ def test_invoke_endpoint(): | ||||
|         EndpointName="asdf", Body="qwer", Accept="sth", TargetModel="tm" | ||||
|     ) | ||||
|     assert body["Body"].read() == b"second body" | ||||
| 
 | ||||
| 
 | ||||
| @mock_s3 | ||||
| @mock_sagemakerruntime | ||||
| def test_invoke_endpoint_async(): | ||||
|     client = boto3.client("sagemaker-runtime", region_name="us-east-1") | ||||
|     base_url = ( | ||||
|         "localhost:5000" if settings.TEST_SERVER_MODE else "motoapi.amazonaws.com" | ||||
|     ) | ||||
| 
 | ||||
|     sagemaker_result = { | ||||
|         "results": [ | ||||
|             { | ||||
|                 "Body": "first body", | ||||
|                 "ContentType": "text/xml", | ||||
|                 "InvokedProductionVariant": "prod", | ||||
|                 "CustomAttributes": "my_attr", | ||||
|             }, | ||||
|             {"Body": "second body"}, | ||||
|         ] | ||||
|     } | ||||
|     requests.post( | ||||
|         f"http://{base_url}/moto-api/static/sagemaker/endpoint-results", | ||||
|         json=sagemaker_result, | ||||
|     ) | ||||
| 
 | ||||
|     # Return the first item from the list | ||||
|     body = client.invoke_endpoint_async(EndpointName="asdf", InputLocation="qwer") | ||||
|     first_output_location = body["OutputLocation"] | ||||
| 
 | ||||
|     # Same input -> same output | ||||
|     body = client.invoke_endpoint_async(EndpointName="asdf", InputLocation="qwer") | ||||
|     assert body["OutputLocation"] == first_output_location | ||||
| 
 | ||||
|     # Different input -> second item | ||||
|     body = client.invoke_endpoint_async( | ||||
|         EndpointName="asdf", InputLocation="asf", InferenceId="sth" | ||||
|     ) | ||||
|     second_output_location = body["OutputLocation"] | ||||
|     assert body["InferenceId"] == "sth" | ||||
| 
 | ||||
|     s3 = boto3.client("s3", "us-east-1") | ||||
|     bucket_name, obj = bucket_and_name_from_url(second_output_location) | ||||
|     resp = s3.get_object(Bucket=bucket_name, Key=obj) | ||||
|     resp = json.loads(resp["Body"].read().decode("utf-8")) | ||||
|     assert resp["Body"] == "second body" | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user