diff --git a/IMPLEMENTATION_COVERAGE.md b/IMPLEMENTATION_COVERAGE.md
index a6a66e7ce..55921b3cf 100644
--- a/IMPLEMENTATION_COVERAGE.md
+++ b/IMPLEMENTATION_COVERAGE.md
@@ -6437,6 +6437,14 @@
- [ ] update_workteam
+## sagemaker-runtime
+
+50% implemented
+
+- [X] invoke_endpoint
+- [ ] invoke_endpoint_async
+
+
## scheduler
100% implemented
@@ -7482,7 +7490,6 @@
- sagemaker-featurestore-runtime
- sagemaker-geospatial
- sagemaker-metrics
-- sagemaker-runtime
- savingsplans
- schemas
- securityhub
diff --git a/docs/docs/services/sagemaker-runtime.rst b/docs/docs/services/sagemaker-runtime.rst
new file mode 100644
index 000000000..eff4d9028
--- /dev/null
+++ b/docs/docs/services/sagemaker-runtime.rst
@@ -0,0 +1,66 @@
+.. _implementedservice_sagemaker-runtime:
+
+.. |start-h3| raw:: html
+
+
+
+.. |end-h3| raw:: html
+
+
+
+=================
+sagemaker-runtime
+=================
+
+.. autoclass:: moto.sagemakerruntime.models.SageMakerRuntimeBackend
+
+|start-h3| Example usage |end-h3|
+
+.. sourcecode:: python
+
+ @mock_sagemakerruntime
+ def test_sagemakerruntime_behaviour:
+ boto3.client("sagemaker-runtime")
+ ...
+
+
+
+|start-h3| Implemented features for this service |end-h3|
+
+- [X] invoke_endpoint
+
+ This call will return static data by default.
+
+ You can use a dedicated API to override this, by configuring a queue of expected results.
+
+ A request to `get_query_results` will take the first result from that queue. Subsequent requests using the same details will return the same result. Other requests using a different QueryExecutionId will take the next result from the queue, or return static data if the queue is empty.
+
+ Configuring this queue by making an HTTP request to `/moto-api/static/sagemaker/endpoint-results`. An example invocation looks like this:
+
+ .. sourcecode:: python
+
+ expected_results = {
+ "account_id": "123456789012", # This is the default - can be omitted
+ "region": "us-east-1", # This is the default - can be omitted
+ "results": [
+ {
+ "Body": "first body",
+ "ContentType": "text/xml",
+ "InvokedProductionVariant": "prod",
+ "CustomAttributes": "my_attr",
+ },
+ # other results as required
+ ],
+ }
+ requests.post(
+ "http://motoapi.amazonaws.com:5000/moto-api/static/sagemaker/endpoint-results",
+ json=expected_results,
+ )
+
+ client = boto3.client("sagemaker", region_name="us-east-1")
+ details = client.invoke_endpoint(EndpointName="asdf", Body="qwer")
+
+
+
+- [ ] invoke_endpoint_async
+
diff --git a/moto/__init__.py b/moto/__init__.py
index 928ae28cf..cb1527a08 100644
--- a/moto/__init__.py
+++ b/moto/__init__.py
@@ -157,6 +157,9 @@ mock_route53resolver = lazy_load(
mock_s3 = lazy_load(".s3", "mock_s3")
mock_s3control = lazy_load(".s3control", "mock_s3control")
mock_sagemaker = lazy_load(".sagemaker", "mock_sagemaker")
+mock_sagemakerruntime = lazy_load(
+ ".sagemakerruntime", "mock_sagemakerruntime", boto3_name="sagemaker-runtime"
+)
mock_scheduler = lazy_load(".scheduler", "mock_scheduler")
mock_sdb = lazy_load(".sdb", "mock_sdb")
mock_secretsmanager = lazy_load(".secretsmanager", "mock_secretsmanager")
diff --git a/moto/backend_index.py b/moto/backend_index.py
index 129e7bd7a..fdc603b74 100644
--- a/moto/backend_index.py
+++ b/moto/backend_index.py
@@ -1,4 +1,4 @@
-# autogenerated by /Users/plussier/dev/github/moto/scripts/update_backend_index.py
+# autogenerated by /home/bblommers/Software/Code/bblommers/moto/scripts/update_backend_index.py
import re
backend_url_patterns = [
@@ -153,6 +153,10 @@ backend_url_patterns = [
re.compile("https?://([0-9]+)\\.s3-control\\.(.+)\\.amazonaws\\.com"),
),
("sagemaker", re.compile("https?://api\\.sagemaker\\.(.+)\\.amazonaws.com")),
+ (
+ "sagemaker-runtime",
+ re.compile("https?://runtime\\.sagemaker\\.(.+)\\.amazonaws\\.com"),
+ ),
("scheduler", re.compile("https?://scheduler\\.(.+)\\.amazonaws\\.com")),
("sdb", re.compile("https?://sdb\\.(.+)\\.amazonaws\\.com")),
("secretsmanager", re.compile("https?://secretsmanager\\.(.+)\\.amazonaws\\.com")),
diff --git a/moto/moto_api/_internal/models.py b/moto/moto_api/_internal/models.py
index ab9d02914..af75c287b 100644
--- a/moto/moto_api/_internal/models.py
+++ b/moto/moto_api/_internal/models.py
@@ -46,6 +46,20 @@ class MotoAPIBackend(BaseBackend):
results = QueryResults(rows=rows, column_info=column_info)
backend.query_results_queue.append(results)
+ def set_sagemaker_result(
+ self,
+ body: str,
+ content_type: str,
+ prod_variant: str,
+ custom_attrs: str,
+ account_id: str,
+ region: str,
+ ) -> None:
+ from moto.sagemakerruntime.models import sagemakerruntime_backends
+
+ backend = sagemakerruntime_backends[account_id][region]
+ backend.results_queue.append((body, content_type, prod_variant, custom_attrs))
+
def set_rds_data_result(
self,
records: Optional[List[List[Dict[str, Any]]]],
diff --git a/moto/moto_api/_internal/responses.py b/moto/moto_api/_internal/responses.py
index d95c1c011..e4c885942 100644
--- a/moto/moto_api/_internal/responses.py
+++ b/moto/moto_api/_internal/responses.py
@@ -168,6 +168,35 @@ class MotoAPIResponse(BaseResponse):
)
return 201, {}, ""
+ def set_sagemaker_result(
+ self,
+ request: Any,
+ full_url: str, # pylint: disable=unused-argument
+ headers: Any,
+ ) -> TYPE_RESPONSE:
+ from .models import moto_api_backend
+
+ request_body_size = int(headers["Content-Length"])
+ body = request.environ["wsgi.input"].read(request_body_size).decode("utf-8")
+ body = json.loads(body)
+ account_id = body.get("account_id", DEFAULT_ACCOUNT_ID)
+ region = body.get("region", "us-east-1")
+
+ for result in body.get("results", []):
+ body = result["Body"]
+ content_type = result.get("ContentType")
+ prod_variant = result.get("InvokedProductionVariant")
+ custom_attrs = result.get("CustomAttributes")
+ moto_api_backend.set_sagemaker_result(
+ body=body,
+ content_type=content_type,
+ prod_variant=prod_variant,
+ custom_attrs=custom_attrs,
+ account_id=account_id,
+ region=region,
+ )
+ return 201, {}, ""
+
def set_rds_data_result(
self,
request: Any,
diff --git a/moto/moto_api/_internal/urls.py b/moto/moto_api/_internal/urls.py
index 41506174c..6b9977355 100644
--- a/moto/moto_api/_internal/urls.py
+++ b/moto/moto_api/_internal/urls.py
@@ -13,6 +13,7 @@ url_paths = {
"{0}/moto-api/reset-auth": response_instance.reset_auth_response,
"{0}/moto-api/seed": response_instance.seed,
"{0}/moto-api/static/athena/query-results": response_instance.set_athena_result,
+ "{0}/moto-api/static/sagemaker/endpoint-results": response_instance.set_sagemaker_result,
"{0}/moto-api/static/rds-data/statement-results": response_instance.set_rds_data_result,
"{0}/moto-api/state-manager/get-transition": response_instance.get_transition,
"{0}/moto-api/state-manager/set-transition": response_instance.set_transition,
diff --git a/moto/moto_server/werkzeug_app.py b/moto/moto_server/werkzeug_app.py
index 5dea96f0c..9de05a1d1 100644
--- a/moto/moto_server/werkzeug_app.py
+++ b/moto/moto_server/werkzeug_app.py
@@ -153,7 +153,10 @@ class DomainDispatcherApplication:
else:
host = "dynamodb"
elif service == "sagemaker":
- host = f"api.{service}.{region}.amazonaws.com"
+ if environ["PATH_INFO"].endswith("invocations"):
+ host = f"runtime.{service}.{region}.amazonaws.com"
+ else:
+ host = f"api.{service}.{region}.amazonaws.com"
elif service == "timestream":
host = f"ingest.{service}.{region}.amazonaws.com"
elif service == "s3" and (
diff --git a/moto/sagemakerruntime/__init__.py b/moto/sagemakerruntime/__init__.py
new file mode 100644
index 000000000..8cdc930a6
--- /dev/null
+++ b/moto/sagemakerruntime/__init__.py
@@ -0,0 +1,5 @@
+"""sagemakerruntime module initialization; sets value for base decorator."""
+from .models import sagemakerruntime_backends
+from ..core.models import base_decorator
+
+mock_sagemakerruntime = base_decorator(sagemakerruntime_backends)
diff --git a/moto/sagemakerruntime/models.py b/moto/sagemakerruntime/models.py
new file mode 100644
index 000000000..c28d3931f
--- /dev/null
+++ b/moto/sagemakerruntime/models.py
@@ -0,0 +1,65 @@
+from moto.core import BaseBackend, BackendDict
+from typing import Dict, List, Tuple
+
+
+class SageMakerRuntimeBackend(BaseBackend):
+ """Implementation of SageMakerRuntime APIs."""
+
+ def __init__(self, region_name: str, account_id: str):
+ super().__init__(region_name, account_id)
+ self.results: Dict[str, Dict[bytes, Tuple[str, str, str, str]]] = {}
+ self.results_queue: List[Tuple[str, str, str, str]] = []
+
+ def invoke_endpoint(
+ self, endpoint_name: str, unique_repr: bytes
+ ) -> Tuple[str, str, str, str]:
+ """
+ This call will return static data by default.
+
+ You can use a dedicated API to override this, by configuring a queue of expected results.
+
+ A request to `get_query_results` will take the first result from that queue. Subsequent requests using the same details will return the same result. Other requests using a different QueryExecutionId will take the next result from the queue, or return static data if the queue is empty.
+
+ Configuring this queue by making an HTTP request to `/moto-api/static/sagemaker/endpoint-results`. An example invocation looks like this:
+
+ .. sourcecode:: python
+
+ expected_results = {
+ "account_id": "123456789012", # This is the default - can be omitted
+ "region": "us-east-1", # This is the default - can be omitted
+ "results": [
+ {
+ "Body": "first body",
+ "ContentType": "text/xml",
+ "InvokedProductionVariant": "prod",
+ "CustomAttributes": "my_attr",
+ },
+ # other results as required
+ ],
+ }
+ requests.post(
+ "http://motoapi.amazonaws.com:5000/moto-api/static/sagemaker/endpoint-results",
+ json=expected_results,
+ )
+
+ client = boto3.client("sagemaker", region_name="us-east-1")
+ details = client.invoke_endpoint(EndpointName="asdf", Body="qwer")
+
+ """
+ if endpoint_name not in self.results:
+ self.results[endpoint_name] = {}
+ if unique_repr in self.results[endpoint_name]:
+ return self.results[endpoint_name][unique_repr]
+ if self.results_queue:
+ self.results[endpoint_name][unique_repr] = self.results_queue.pop(0)
+ else:
+ self.results[endpoint_name][unique_repr] = (
+ "body",
+ "content_type",
+ "invoked_production_variant",
+ "custom_attributes",
+ )
+ return self.results[endpoint_name][unique_repr]
+
+
+sagemakerruntime_backends = BackendDict(SageMakerRuntimeBackend, "sagemaker-runtime")
diff --git a/moto/sagemakerruntime/responses.py b/moto/sagemakerruntime/responses.py
new file mode 100644
index 000000000..24cc48633
--- /dev/null
+++ b/moto/sagemakerruntime/responses.py
@@ -0,0 +1,44 @@
+import base64
+import json
+
+from moto.core.common_types import TYPE_RESPONSE
+from moto.core.responses import BaseResponse
+from .models import sagemakerruntime_backends, SageMakerRuntimeBackend
+
+
+class SageMakerRuntimeResponse(BaseResponse):
+ """Handler for SageMakerRuntime requests and responses."""
+
+ def __init__(self) -> None:
+ super().__init__(service_name="sagemaker-runtime")
+
+ @property
+ def sagemakerruntime_backend(self) -> SageMakerRuntimeBackend:
+ """Return backend instance specific for this region."""
+ return sagemakerruntime_backends[self.current_account][self.region]
+
+ def invoke_endpoint(self) -> TYPE_RESPONSE:
+ params = self._get_params()
+ unique_repr = {
+ key: value
+ for key, value in self.headers.items()
+ if key.lower().startswith("x-amzn-sagemaker")
+ }
+ unique_repr["Accept"] = self.headers.get("Accept")
+ unique_repr["Body"] = self.body
+ endpoint_name = params.get("EndpointName")
+ (
+ body,
+ content_type,
+ invoked_production_variant,
+ custom_attributes,
+ ) = self.sagemakerruntime_backend.invoke_endpoint(
+ endpoint_name=endpoint_name, # type: ignore[arg-type]
+ unique_repr=base64.b64encode(json.dumps(unique_repr).encode("utf-8")),
+ )
+ headers = {"Content-Type": content_type}
+ if invoked_production_variant:
+ headers["x-Amzn-Invoked-Production-Variant"] = invoked_production_variant
+ if custom_attributes:
+ headers["X-Amzn-SageMaker-Custom-Attributes"] = custom_attributes
+ return 200, headers, body
diff --git a/moto/sagemakerruntime/urls.py b/moto/sagemakerruntime/urls.py
new file mode 100644
index 000000000..01c753437
--- /dev/null
+++ b/moto/sagemakerruntime/urls.py
@@ -0,0 +1,14 @@
+"""sagemakerruntime base URL and path."""
+from .responses import SageMakerRuntimeResponse
+
+url_bases = [
+ r"https?://runtime\.sagemaker\.(.+)\.amazonaws\.com",
+]
+
+
+response = SageMakerRuntimeResponse()
+
+
+url_paths = {
+ "{0}/endpoints/(?P[^/]+)/invocations$": response.dispatch,
+}
diff --git a/tests/test_sagemakerruntime/__init__.py b/tests/test_sagemakerruntime/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/test_sagemakerruntime/test_sagemakerruntime.py b/tests/test_sagemakerruntime/test_sagemakerruntime.py
new file mode 100644
index 000000000..60b256d5a
--- /dev/null
+++ b/tests/test_sagemakerruntime/test_sagemakerruntime.py
@@ -0,0 +1,56 @@
+import boto3
+import requests
+
+from moto import mock_sagemakerruntime, settings
+
+# See our Development Tips on writing tests for hints on how to write good tests:
+# http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html
+
+
+@mock_sagemakerruntime
+def test_invoke_endpoint__default_results():
+ client = boto3.client("sagemaker-runtime", region_name="ap-southeast-1")
+ body = client.invoke_endpoint(
+ EndpointName="asdf", Body="qwer", Accept="sth", TargetModel="tm"
+ )
+
+ assert body["Body"].read() == b"body"
+ assert body["CustomAttributes"] == "custom_attributes"
+
+
+@mock_sagemakerruntime
+def test_invoke_endpoint():
+ client = boto3.client("sagemaker-runtime", region_name="us-east-1")
+ base_url = (
+ "localhost:5000" if settings.TEST_SERVER_MODE else "motoapi.amazonaws.com"
+ )
+
+ sagemaker_result = {
+ "results": [
+ {
+ "Body": "first body",
+ "ContentType": "text/xml",
+ "InvokedProductionVariant": "prod",
+ "CustomAttributes": "my_attr",
+ },
+ {"Body": "second body"},
+ ]
+ }
+ requests.post(
+ f"http://{base_url}/moto-api/static/sagemaker/endpoint-results",
+ json=sagemaker_result,
+ )
+
+ # Return the first item from the list
+ body = client.invoke_endpoint(EndpointName="asdf", Body="qwer")
+ assert body["Body"].read() == b"first body"
+
+ # Same input -> same output
+ body = client.invoke_endpoint(EndpointName="asdf", Body="qwer")
+ assert body["Body"].read() == b"first body"
+
+ # Different input -> second item
+ body = client.invoke_endpoint(
+ EndpointName="asdf", Body="qwer", Accept="sth", TargetModel="tm"
+ )
+ assert body["Body"].read() == b"second body"