diff --git a/IMPLEMENTATION_COVERAGE.md b/IMPLEMENTATION_COVERAGE.md index a6a66e7ce..55921b3cf 100644 --- a/IMPLEMENTATION_COVERAGE.md +++ b/IMPLEMENTATION_COVERAGE.md @@ -6437,6 +6437,14 @@ - [ ] update_workteam +## sagemaker-runtime +
+50% implemented + +- [X] invoke_endpoint +- [ ] invoke_endpoint_async +
+ ## scheduler
100% implemented @@ -7482,7 +7490,6 @@ - sagemaker-featurestore-runtime - sagemaker-geospatial - sagemaker-metrics -- sagemaker-runtime - savingsplans - schemas - securityhub diff --git a/docs/docs/services/sagemaker-runtime.rst b/docs/docs/services/sagemaker-runtime.rst new file mode 100644 index 000000000..eff4d9028 --- /dev/null +++ b/docs/docs/services/sagemaker-runtime.rst @@ -0,0 +1,66 @@ +.. _implementedservice_sagemaker-runtime: + +.. |start-h3| raw:: html + +

+ +.. |end-h3| raw:: html + +

+ +================= +sagemaker-runtime +================= + +.. autoclass:: moto.sagemakerruntime.models.SageMakerRuntimeBackend + +|start-h3| Example usage |end-h3| + +.. sourcecode:: python + + @mock_sagemakerruntime + def test_sagemakerruntime_behaviour: + boto3.client("sagemaker-runtime") + ... + + + +|start-h3| Implemented features for this service |end-h3| + +- [X] invoke_endpoint + + This call will return static data by default. + + You can use a dedicated API to override this, by configuring a queue of expected results. + + A request to `get_query_results` will take the first result from that queue. Subsequent requests using the same details will return the same result. Other requests using a different QueryExecutionId will take the next result from the queue, or return static data if the queue is empty. + + Configuring this queue by making an HTTP request to `/moto-api/static/sagemaker/endpoint-results`. An example invocation looks like this: + + .. sourcecode:: python + + expected_results = { + "account_id": "123456789012", # This is the default - can be omitted + "region": "us-east-1", # This is the default - can be omitted + "results": [ + { + "Body": "first body", + "ContentType": "text/xml", + "InvokedProductionVariant": "prod", + "CustomAttributes": "my_attr", + }, + # other results as required + ], + } + requests.post( + "http://motoapi.amazonaws.com:5000/moto-api/static/sagemaker/endpoint-results", + json=expected_results, + ) + + client = boto3.client("sagemaker", region_name="us-east-1") + details = client.invoke_endpoint(EndpointName="asdf", Body="qwer") + + + +- [ ] invoke_endpoint_async + diff --git a/moto/__init__.py b/moto/__init__.py index 928ae28cf..cb1527a08 100644 --- a/moto/__init__.py +++ b/moto/__init__.py @@ -157,6 +157,9 @@ mock_route53resolver = lazy_load( mock_s3 = lazy_load(".s3", "mock_s3") mock_s3control = lazy_load(".s3control", "mock_s3control") mock_sagemaker = lazy_load(".sagemaker", "mock_sagemaker") +mock_sagemakerruntime = lazy_load( + ".sagemakerruntime", "mock_sagemakerruntime", boto3_name="sagemaker-runtime" +) mock_scheduler = lazy_load(".scheduler", "mock_scheduler") mock_sdb = lazy_load(".sdb", "mock_sdb") mock_secretsmanager = lazy_load(".secretsmanager", "mock_secretsmanager") diff --git a/moto/backend_index.py b/moto/backend_index.py index 129e7bd7a..fdc603b74 100644 --- a/moto/backend_index.py +++ b/moto/backend_index.py @@ -1,4 +1,4 @@ -# autogenerated by /Users/plussier/dev/github/moto/scripts/update_backend_index.py +# autogenerated by /home/bblommers/Software/Code/bblommers/moto/scripts/update_backend_index.py import re backend_url_patterns = [ @@ -153,6 +153,10 @@ backend_url_patterns = [ re.compile("https?://([0-9]+)\\.s3-control\\.(.+)\\.amazonaws\\.com"), ), ("sagemaker", re.compile("https?://api\\.sagemaker\\.(.+)\\.amazonaws.com")), + ( + "sagemaker-runtime", + re.compile("https?://runtime\\.sagemaker\\.(.+)\\.amazonaws\\.com"), + ), ("scheduler", re.compile("https?://scheduler\\.(.+)\\.amazonaws\\.com")), ("sdb", re.compile("https?://sdb\\.(.+)\\.amazonaws\\.com")), ("secretsmanager", re.compile("https?://secretsmanager\\.(.+)\\.amazonaws\\.com")), diff --git a/moto/moto_api/_internal/models.py b/moto/moto_api/_internal/models.py index ab9d02914..af75c287b 100644 --- a/moto/moto_api/_internal/models.py +++ b/moto/moto_api/_internal/models.py @@ -46,6 +46,20 @@ class MotoAPIBackend(BaseBackend): results = QueryResults(rows=rows, column_info=column_info) backend.query_results_queue.append(results) + def set_sagemaker_result( + self, + body: str, + content_type: str, + prod_variant: str, + custom_attrs: str, + account_id: str, + region: str, + ) -> None: + from moto.sagemakerruntime.models import sagemakerruntime_backends + + backend = sagemakerruntime_backends[account_id][region] + backend.results_queue.append((body, content_type, prod_variant, custom_attrs)) + def set_rds_data_result( self, records: Optional[List[List[Dict[str, Any]]]], diff --git a/moto/moto_api/_internal/responses.py b/moto/moto_api/_internal/responses.py index d95c1c011..e4c885942 100644 --- a/moto/moto_api/_internal/responses.py +++ b/moto/moto_api/_internal/responses.py @@ -168,6 +168,35 @@ class MotoAPIResponse(BaseResponse): ) return 201, {}, "" + def set_sagemaker_result( + self, + request: Any, + full_url: str, # pylint: disable=unused-argument + headers: Any, + ) -> TYPE_RESPONSE: + from .models import moto_api_backend + + request_body_size = int(headers["Content-Length"]) + body = request.environ["wsgi.input"].read(request_body_size).decode("utf-8") + body = json.loads(body) + account_id = body.get("account_id", DEFAULT_ACCOUNT_ID) + region = body.get("region", "us-east-1") + + for result in body.get("results", []): + body = result["Body"] + content_type = result.get("ContentType") + prod_variant = result.get("InvokedProductionVariant") + custom_attrs = result.get("CustomAttributes") + moto_api_backend.set_sagemaker_result( + body=body, + content_type=content_type, + prod_variant=prod_variant, + custom_attrs=custom_attrs, + account_id=account_id, + region=region, + ) + return 201, {}, "" + def set_rds_data_result( self, request: Any, diff --git a/moto/moto_api/_internal/urls.py b/moto/moto_api/_internal/urls.py index 41506174c..6b9977355 100644 --- a/moto/moto_api/_internal/urls.py +++ b/moto/moto_api/_internal/urls.py @@ -13,6 +13,7 @@ url_paths = { "{0}/moto-api/reset-auth": response_instance.reset_auth_response, "{0}/moto-api/seed": response_instance.seed, "{0}/moto-api/static/athena/query-results": response_instance.set_athena_result, + "{0}/moto-api/static/sagemaker/endpoint-results": response_instance.set_sagemaker_result, "{0}/moto-api/static/rds-data/statement-results": response_instance.set_rds_data_result, "{0}/moto-api/state-manager/get-transition": response_instance.get_transition, "{0}/moto-api/state-manager/set-transition": response_instance.set_transition, diff --git a/moto/moto_server/werkzeug_app.py b/moto/moto_server/werkzeug_app.py index 5dea96f0c..9de05a1d1 100644 --- a/moto/moto_server/werkzeug_app.py +++ b/moto/moto_server/werkzeug_app.py @@ -153,7 +153,10 @@ class DomainDispatcherApplication: else: host = "dynamodb" elif service == "sagemaker": - host = f"api.{service}.{region}.amazonaws.com" + if environ["PATH_INFO"].endswith("invocations"): + host = f"runtime.{service}.{region}.amazonaws.com" + else: + host = f"api.{service}.{region}.amazonaws.com" elif service == "timestream": host = f"ingest.{service}.{region}.amazonaws.com" elif service == "s3" and ( diff --git a/moto/sagemakerruntime/__init__.py b/moto/sagemakerruntime/__init__.py new file mode 100644 index 000000000..8cdc930a6 --- /dev/null +++ b/moto/sagemakerruntime/__init__.py @@ -0,0 +1,5 @@ +"""sagemakerruntime module initialization; sets value for base decorator.""" +from .models import sagemakerruntime_backends +from ..core.models import base_decorator + +mock_sagemakerruntime = base_decorator(sagemakerruntime_backends) diff --git a/moto/sagemakerruntime/models.py b/moto/sagemakerruntime/models.py new file mode 100644 index 000000000..c28d3931f --- /dev/null +++ b/moto/sagemakerruntime/models.py @@ -0,0 +1,65 @@ +from moto.core import BaseBackend, BackendDict +from typing import Dict, List, Tuple + + +class SageMakerRuntimeBackend(BaseBackend): + """Implementation of SageMakerRuntime APIs.""" + + def __init__(self, region_name: str, account_id: str): + super().__init__(region_name, account_id) + self.results: Dict[str, Dict[bytes, Tuple[str, str, str, str]]] = {} + self.results_queue: List[Tuple[str, str, str, str]] = [] + + def invoke_endpoint( + self, endpoint_name: str, unique_repr: bytes + ) -> Tuple[str, str, str, str]: + """ + This call will return static data by default. + + You can use a dedicated API to override this, by configuring a queue of expected results. + + A request to `get_query_results` will take the first result from that queue. Subsequent requests using the same details will return the same result. Other requests using a different QueryExecutionId will take the next result from the queue, or return static data if the queue is empty. + + Configuring this queue by making an HTTP request to `/moto-api/static/sagemaker/endpoint-results`. An example invocation looks like this: + + .. sourcecode:: python + + expected_results = { + "account_id": "123456789012", # This is the default - can be omitted + "region": "us-east-1", # This is the default - can be omitted + "results": [ + { + "Body": "first body", + "ContentType": "text/xml", + "InvokedProductionVariant": "prod", + "CustomAttributes": "my_attr", + }, + # other results as required + ], + } + requests.post( + "http://motoapi.amazonaws.com:5000/moto-api/static/sagemaker/endpoint-results", + json=expected_results, + ) + + client = boto3.client("sagemaker", region_name="us-east-1") + details = client.invoke_endpoint(EndpointName="asdf", Body="qwer") + + """ + if endpoint_name not in self.results: + self.results[endpoint_name] = {} + if unique_repr in self.results[endpoint_name]: + return self.results[endpoint_name][unique_repr] + if self.results_queue: + self.results[endpoint_name][unique_repr] = self.results_queue.pop(0) + else: + self.results[endpoint_name][unique_repr] = ( + "body", + "content_type", + "invoked_production_variant", + "custom_attributes", + ) + return self.results[endpoint_name][unique_repr] + + +sagemakerruntime_backends = BackendDict(SageMakerRuntimeBackend, "sagemaker-runtime") diff --git a/moto/sagemakerruntime/responses.py b/moto/sagemakerruntime/responses.py new file mode 100644 index 000000000..24cc48633 --- /dev/null +++ b/moto/sagemakerruntime/responses.py @@ -0,0 +1,44 @@ +import base64 +import json + +from moto.core.common_types import TYPE_RESPONSE +from moto.core.responses import BaseResponse +from .models import sagemakerruntime_backends, SageMakerRuntimeBackend + + +class SageMakerRuntimeResponse(BaseResponse): + """Handler for SageMakerRuntime requests and responses.""" + + def __init__(self) -> None: + super().__init__(service_name="sagemaker-runtime") + + @property + def sagemakerruntime_backend(self) -> SageMakerRuntimeBackend: + """Return backend instance specific for this region.""" + return sagemakerruntime_backends[self.current_account][self.region] + + def invoke_endpoint(self) -> TYPE_RESPONSE: + params = self._get_params() + unique_repr = { + key: value + for key, value in self.headers.items() + if key.lower().startswith("x-amzn-sagemaker") + } + unique_repr["Accept"] = self.headers.get("Accept") + unique_repr["Body"] = self.body + endpoint_name = params.get("EndpointName") + ( + body, + content_type, + invoked_production_variant, + custom_attributes, + ) = self.sagemakerruntime_backend.invoke_endpoint( + endpoint_name=endpoint_name, # type: ignore[arg-type] + unique_repr=base64.b64encode(json.dumps(unique_repr).encode("utf-8")), + ) + headers = {"Content-Type": content_type} + if invoked_production_variant: + headers["x-Amzn-Invoked-Production-Variant"] = invoked_production_variant + if custom_attributes: + headers["X-Amzn-SageMaker-Custom-Attributes"] = custom_attributes + return 200, headers, body diff --git a/moto/sagemakerruntime/urls.py b/moto/sagemakerruntime/urls.py new file mode 100644 index 000000000..01c753437 --- /dev/null +++ b/moto/sagemakerruntime/urls.py @@ -0,0 +1,14 @@ +"""sagemakerruntime base URL and path.""" +from .responses import SageMakerRuntimeResponse + +url_bases = [ + r"https?://runtime\.sagemaker\.(.+)\.amazonaws\.com", +] + + +response = SageMakerRuntimeResponse() + + +url_paths = { + "{0}/endpoints/(?P[^/]+)/invocations$": response.dispatch, +} diff --git a/tests/test_sagemakerruntime/__init__.py b/tests/test_sagemakerruntime/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_sagemakerruntime/test_sagemakerruntime.py b/tests/test_sagemakerruntime/test_sagemakerruntime.py new file mode 100644 index 000000000..60b256d5a --- /dev/null +++ b/tests/test_sagemakerruntime/test_sagemakerruntime.py @@ -0,0 +1,56 @@ +import boto3 +import requests + +from moto import mock_sagemakerruntime, settings + +# See our Development Tips on writing tests for hints on how to write good tests: +# http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html + + +@mock_sagemakerruntime +def test_invoke_endpoint__default_results(): + client = boto3.client("sagemaker-runtime", region_name="ap-southeast-1") + body = client.invoke_endpoint( + EndpointName="asdf", Body="qwer", Accept="sth", TargetModel="tm" + ) + + assert body["Body"].read() == b"body" + assert body["CustomAttributes"] == "custom_attributes" + + +@mock_sagemakerruntime +def test_invoke_endpoint(): + client = boto3.client("sagemaker-runtime", region_name="us-east-1") + base_url = ( + "localhost:5000" if settings.TEST_SERVER_MODE else "motoapi.amazonaws.com" + ) + + sagemaker_result = { + "results": [ + { + "Body": "first body", + "ContentType": "text/xml", + "InvokedProductionVariant": "prod", + "CustomAttributes": "my_attr", + }, + {"Body": "second body"}, + ] + } + requests.post( + f"http://{base_url}/moto-api/static/sagemaker/endpoint-results", + json=sagemaker_result, + ) + + # Return the first item from the list + body = client.invoke_endpoint(EndpointName="asdf", Body="qwer") + assert body["Body"].read() == b"first body" + + # Same input -> same output + body = client.invoke_endpoint(EndpointName="asdf", Body="qwer") + assert body["Body"].read() == b"first body" + + # Different input -> second item + body = client.invoke_endpoint( + EndpointName="asdf", Body="qwer", Accept="sth", TargetModel="tm" + ) + assert body["Body"].read() == b"second body"