SagemakerRuntime: invoke_endpoint_async() implementation (#7211)
This commit is contained in:
		
							parent
							
								
									455fbd5eaa
								
							
						
					
					
						commit
						3b3f718d41
					
				| @ -1,6 +1,8 @@ | |||||||
|  | import json | ||||||
| from typing import Dict, List, Tuple | from typing import Dict, List, Tuple | ||||||
| 
 | 
 | ||||||
| from moto.core import BackendDict, BaseBackend | from moto.core import BackendDict, BaseBackend | ||||||
|  | from moto.moto_api._internal import mock_random as random | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class SageMakerRuntimeBackend(BaseBackend): | class SageMakerRuntimeBackend(BaseBackend): | ||||||
| @ -8,6 +10,7 @@ class SageMakerRuntimeBackend(BaseBackend): | |||||||
| 
 | 
 | ||||||
|     def __init__(self, region_name: str, account_id: str): |     def __init__(self, region_name: str, account_id: str): | ||||||
|         super().__init__(region_name, account_id) |         super().__init__(region_name, account_id) | ||||||
|  |         self.async_results: Dict[str, Dict[str, str]] = {} | ||||||
|         self.results: Dict[str, Dict[bytes, Tuple[str, str, str, str]]] = {} |         self.results: Dict[str, Dict[bytes, Tuple[str, str, str, str]]] = {} | ||||||
|         self.results_queue: List[Tuple[str, str, str, str]] = [] |         self.results_queue: List[Tuple[str, str, str, str]] = [] | ||||||
| 
 | 
 | ||||||
| @ -62,5 +65,37 @@ class SageMakerRuntimeBackend(BaseBackend): | |||||||
|             ) |             ) | ||||||
|         return self.results[endpoint_name][unique_repr] |         return self.results[endpoint_name][unique_repr] | ||||||
| 
 | 
 | ||||||
|  |     def invoke_endpoint_async(self, endpoint_name: str, input_location: str) -> str: | ||||||
|  |         if endpoint_name not in self.async_results: | ||||||
|  |             self.async_results[endpoint_name] = {} | ||||||
|  |         if input_location in self.async_results[endpoint_name]: | ||||||
|  |             return self.async_results[endpoint_name][input_location] | ||||||
|  |         if self.results_queue: | ||||||
|  |             body, _type, variant, attrs = self.results_queue.pop(0) | ||||||
|  |         else: | ||||||
|  |             body = "body" | ||||||
|  |             _type = "content_type" | ||||||
|  |             variant = "invoked_production_variant" | ||||||
|  |             attrs = "custom_attributes" | ||||||
|  |         json_data = { | ||||||
|  |             "Body": body, | ||||||
|  |             "ContentType": _type, | ||||||
|  |             "InvokedProductionVariant": variant, | ||||||
|  |             "CustomAttributes": attrs, | ||||||
|  |         } | ||||||
|  |         output = json.dumps(json_data).encode("utf-8") | ||||||
|  | 
 | ||||||
|  |         output_bucket = f"sagemaker-output-{random.uuid4()}" | ||||||
|  |         output_location = "response.json" | ||||||
|  |         from moto.s3.models import s3_backends | ||||||
|  | 
 | ||||||
|  |         s3_backend = s3_backends[self.account_id]["global"] | ||||||
|  |         s3_backend.create_bucket(output_bucket, region_name=self.region_name) | ||||||
|  |         s3_backend.put_object(output_bucket, output_location, value=output) | ||||||
|  |         self.async_results[endpoint_name][ | ||||||
|  |             input_location | ||||||
|  |         ] = f"s3://{output_bucket}/{output_location}" | ||||||
|  |         return self.async_results[endpoint_name][input_location] | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| sagemakerruntime_backends = BackendDict(SageMakerRuntimeBackend, "sagemaker-runtime") | sagemakerruntime_backends = BackendDict(SageMakerRuntimeBackend, "sagemaker-runtime") | ||||||
|  | |||||||
| @ -3,6 +3,7 @@ import json | |||||||
| 
 | 
 | ||||||
| from moto.core.common_types import TYPE_RESPONSE | from moto.core.common_types import TYPE_RESPONSE | ||||||
| from moto.core.responses import BaseResponse | from moto.core.responses import BaseResponse | ||||||
|  | from moto.moto_api._internal import mock_random as random | ||||||
| 
 | 
 | ||||||
| from .models import SageMakerRuntimeBackend, sagemakerruntime_backends | from .models import SageMakerRuntimeBackend, sagemakerruntime_backends | ||||||
| 
 | 
 | ||||||
| @ -43,3 +44,14 @@ class SageMakerRuntimeResponse(BaseResponse): | |||||||
|         if custom_attributes: |         if custom_attributes: | ||||||
|             headers["X-Amzn-SageMaker-Custom-Attributes"] = custom_attributes |             headers["X-Amzn-SageMaker-Custom-Attributes"] = custom_attributes | ||||||
|         return 200, headers, body |         return 200, headers, body | ||||||
|  | 
 | ||||||
|  |     def invoke_endpoint_async(self) -> TYPE_RESPONSE: | ||||||
|  |         endpoint_name = self.path.split("/")[2] | ||||||
|  |         input_location = self.headers.get("X-Amzn-SageMaker-InputLocation") | ||||||
|  |         inference_id = self.headers.get("X-Amzn-SageMaker-Inference-Id") | ||||||
|  |         location = self.sagemakerruntime_backend.invoke_endpoint_async( | ||||||
|  |             endpoint_name, input_location | ||||||
|  |         ) | ||||||
|  |         resp = {"InferenceId": inference_id or str(random.uuid4())} | ||||||
|  |         headers = {"X-Amzn-SageMaker-OutputLocation": location} | ||||||
|  |         return 200, headers, json.dumps(resp) | ||||||
|  | |||||||
| @ -10,5 +10,6 @@ response = SageMakerRuntimeResponse() | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| url_paths = { | url_paths = { | ||||||
|  |     "{0}/endpoints/(?P<name>[^/]+)/async-invocations$": response.dispatch, | ||||||
|     "{0}/endpoints/(?P<name>[^/]+)/invocations$": response.dispatch, |     "{0}/endpoints/(?P<name>[^/]+)/invocations$": response.dispatch, | ||||||
| } | } | ||||||
|  | |||||||
| @ -1,7 +1,10 @@ | |||||||
|  | import json | ||||||
|  | 
 | ||||||
| import boto3 | import boto3 | ||||||
| import requests | import requests | ||||||
| 
 | 
 | ||||||
| from moto import mock_sagemakerruntime, settings | from moto import mock_s3, mock_sagemakerruntime, settings | ||||||
|  | from moto.s3.utils import bucket_and_name_from_url | ||||||
| 
 | 
 | ||||||
| # See our Development Tips on writing tests for hints on how to write good tests: | # See our Development Tips on writing tests for hints on how to write good tests: | ||||||
| # http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html | # http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html | ||||||
| @ -54,3 +57,49 @@ def test_invoke_endpoint(): | |||||||
|         EndpointName="asdf", Body="qwer", Accept="sth", TargetModel="tm" |         EndpointName="asdf", Body="qwer", Accept="sth", TargetModel="tm" | ||||||
|     ) |     ) | ||||||
|     assert body["Body"].read() == b"second body" |     assert body["Body"].read() == b"second body" | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @mock_s3 | ||||||
|  | @mock_sagemakerruntime | ||||||
|  | def test_invoke_endpoint_async(): | ||||||
|  |     client = boto3.client("sagemaker-runtime", region_name="us-east-1") | ||||||
|  |     base_url = ( | ||||||
|  |         "localhost:5000" if settings.TEST_SERVER_MODE else "motoapi.amazonaws.com" | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|  |     sagemaker_result = { | ||||||
|  |         "results": [ | ||||||
|  |             { | ||||||
|  |                 "Body": "first body", | ||||||
|  |                 "ContentType": "text/xml", | ||||||
|  |                 "InvokedProductionVariant": "prod", | ||||||
|  |                 "CustomAttributes": "my_attr", | ||||||
|  |             }, | ||||||
|  |             {"Body": "second body"}, | ||||||
|  |         ] | ||||||
|  |     } | ||||||
|  |     requests.post( | ||||||
|  |         f"http://{base_url}/moto-api/static/sagemaker/endpoint-results", | ||||||
|  |         json=sagemaker_result, | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|  |     # Return the first item from the list | ||||||
|  |     body = client.invoke_endpoint_async(EndpointName="asdf", InputLocation="qwer") | ||||||
|  |     first_output_location = body["OutputLocation"] | ||||||
|  | 
 | ||||||
|  |     # Same input -> same output | ||||||
|  |     body = client.invoke_endpoint_async(EndpointName="asdf", InputLocation="qwer") | ||||||
|  |     assert body["OutputLocation"] == first_output_location | ||||||
|  | 
 | ||||||
|  |     # Different input -> second item | ||||||
|  |     body = client.invoke_endpoint_async( | ||||||
|  |         EndpointName="asdf", InputLocation="asf", InferenceId="sth" | ||||||
|  |     ) | ||||||
|  |     second_output_location = body["OutputLocation"] | ||||||
|  |     assert body["InferenceId"] == "sth" | ||||||
|  | 
 | ||||||
|  |     s3 = boto3.client("s3", "us-east-1") | ||||||
|  |     bucket_name, obj = bucket_and_name_from_url(second_output_location) | ||||||
|  |     resp = s3.get_object(Bucket=bucket_name, Key=obj) | ||||||
|  |     resp = json.loads(resp["Body"].read().decode("utf-8")) | ||||||
|  |     assert resp["Body"] == "second body" | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user