Feature: Sagemaker Runtime (#6747)

This commit is contained in:
Bert Blommers 2023-09-01 07:06:50 +00:00 committed by GitHub
parent 7098388ee4
commit 5dd649378c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 314 additions and 3 deletions

View File

@ -6437,6 +6437,14 @@
- [ ] update_workteam
</details>
## sagemaker-runtime
<details>
<summary>50% implemented</summary>
- [X] invoke_endpoint
- [ ] invoke_endpoint_async
</details>
## scheduler
<details>
<summary>100% implemented</summary>
@ -7482,7 +7490,6 @@
- sagemaker-featurestore-runtime
- sagemaker-geospatial
- sagemaker-metrics
- sagemaker-runtime
- savingsplans
- schemas
- securityhub

View File

@ -0,0 +1,66 @@
.. _implementedservice_sagemaker-runtime:
.. |start-h3| raw:: html
<h3>
.. |end-h3| raw:: html
</h3>
=================
sagemaker-runtime
=================
.. autoclass:: moto.sagemakerruntime.models.SageMakerRuntimeBackend
|start-h3| Example usage |end-h3|
.. sourcecode:: python
@mock_sagemakerruntime
def test_sagemakerruntime_behaviour():
boto3.client("sagemaker-runtime")
...
|start-h3| Implemented features for this service |end-h3|
- [X] invoke_endpoint
This call will return static data by default.
You can use a dedicated API to override this, by configuring a queue of expected results.
A request to `invoke_endpoint` will take the first result from that queue. Subsequent requests using the same details will return the same result. Other requests using different details (for example a different endpoint name or body) will take the next result from the queue, or return static data if the queue is empty.
Configure this queue by making an HTTP request to `/moto-api/static/sagemaker/endpoint-results`. An example invocation looks like this:
.. sourcecode:: python
expected_results = {
"account_id": "123456789012", # This is the default - can be omitted
"region": "us-east-1", # This is the default - can be omitted
"results": [
{
"Body": "first body",
"ContentType": "text/xml",
"InvokedProductionVariant": "prod",
"CustomAttributes": "my_attr",
},
# other results as required
],
}
requests.post(
"http://motoapi.amazonaws.com:5000/moto-api/static/sagemaker/endpoint-results",
json=expected_results,
)
client = boto3.client("sagemaker-runtime", region_name="us-east-1")
details = client.invoke_endpoint(EndpointName="asdf", Body="qwer")
- [ ] invoke_endpoint_async

View File

@ -157,6 +157,9 @@ mock_route53resolver = lazy_load(
mock_s3 = lazy_load(".s3", "mock_s3")
mock_s3control = lazy_load(".s3control", "mock_s3control")
mock_sagemaker = lazy_load(".sagemaker", "mock_sagemaker")
# The boto3 service name is passed explicitly via `boto3_name` - presumably
# because it is hyphenated ("sagemaker-runtime") and so differs from the
# module/decorator name.
mock_sagemakerruntime = lazy_load(
    ".sagemakerruntime", "mock_sagemakerruntime", boto3_name="sagemaker-runtime"
)
mock_scheduler = lazy_load(".scheduler", "mock_scheduler")
mock_sdb = lazy_load(".sdb", "mock_sdb")
mock_secretsmanager = lazy_load(".secretsmanager", "mock_secretsmanager")

View File

@ -1,4 +1,4 @@
# autogenerated by /Users/plussier/dev/github/moto/scripts/update_backend_index.py
# autogenerated by /home/bblommers/Software/Code/bblommers/moto/scripts/update_backend_index.py
import re
backend_url_patterns = [
@ -153,6 +153,10 @@ backend_url_patterns = [
re.compile("https?://([0-9]+)\\.s3-control\\.(.+)\\.amazonaws\\.com"),
),
("sagemaker", re.compile("https?://api\\.sagemaker\\.(.+)\\.amazonaws.com")),
(
"sagemaker-runtime",
re.compile("https?://runtime\\.sagemaker\\.(.+)\\.amazonaws\\.com"),
),
("scheduler", re.compile("https?://scheduler\\.(.+)\\.amazonaws\\.com")),
("sdb", re.compile("https?://sdb\\.(.+)\\.amazonaws\\.com")),
("secretsmanager", re.compile("https?://secretsmanager\\.(.+)\\.amazonaws\\.com")),

View File

@ -46,6 +46,20 @@ class MotoAPIBackend(BaseBackend):
results = QueryResults(rows=rows, column_info=column_info)
backend.query_results_queue.append(results)
def set_sagemaker_result(
    self,
    body: str,
    content_type: str,
    prod_variant: str,
    custom_attrs: str,
    account_id: str,
    region: str,
) -> None:
    """Queue one pre-configured SageMaker Runtime result.

    The values are stored as a ``(body, content_type, prod_variant,
    custom_attrs)`` tuple at the end of the ``results_queue`` of the
    SageMaker Runtime backend that belongs to ``account_id`` and ``region``.
    """
    # Imported inside the method rather than at module level, as before.
    from moto.sagemakerruntime.models import sagemakerruntime_backends

    queued_result = (body, content_type, prod_variant, custom_attrs)
    sagemakerruntime_backends[account_id][region].results_queue.append(
        queued_result
    )
def set_rds_data_result(
self,
records: Optional[List[List[Dict[str, Any]]]],

View File

@ -168,6 +168,35 @@ class MotoAPIResponse(BaseResponse):
)
return 201, {}, ""
def set_sagemaker_result(
    self,
    request: Any,
    full_url: str,  # pylint: disable=unused-argument
    headers: Any,
) -> TYPE_RESPONSE:
    """Register the SageMaker Runtime results posted to the moto API.

    The JSON request body may contain ``account_id`` and ``region`` (falling
    back to ``DEFAULT_ACCOUNT_ID`` and ``us-east-1``) plus a ``results``
    list. Every entry must provide ``Body``; ``ContentType``,
    ``InvokedProductionVariant`` and ``CustomAttributes`` are optional and
    forwarded as ``None`` when missing.

    Returns an empty 201 response.
    """
    from .models import moto_api_backend

    # Read exactly Content-Length bytes from the raw WSGI input stream.
    content_length = int(headers["Content-Length"])
    raw_payload = request.environ["wsgi.input"].read(content_length)
    payload = json.loads(raw_payload.decode("utf-8"))

    account_id = payload.get("account_id", DEFAULT_ACCOUNT_ID)
    region = payload.get("region", "us-east-1")

    # Each entry becomes one queued result, in the order it was supplied.
    for entry in payload.get("results", []):
        moto_api_backend.set_sagemaker_result(
            body=entry["Body"],
            content_type=entry.get("ContentType"),
            prod_variant=entry.get("InvokedProductionVariant"),
            custom_attrs=entry.get("CustomAttributes"),
            account_id=account_id,
            region=region,
        )
    return 201, {}, ""
def set_rds_data_result(
self,
request: Any,

View File

@ -13,6 +13,7 @@ url_paths = {
"{0}/moto-api/reset-auth": response_instance.reset_auth_response,
"{0}/moto-api/seed": response_instance.seed,
"{0}/moto-api/static/athena/query-results": response_instance.set_athena_result,
"{0}/moto-api/static/sagemaker/endpoint-results": response_instance.set_sagemaker_result,
"{0}/moto-api/static/rds-data/statement-results": response_instance.set_rds_data_result,
"{0}/moto-api/state-manager/get-transition": response_instance.get_transition,
"{0}/moto-api/state-manager/set-transition": response_instance.set_transition,

View File

@ -153,7 +153,10 @@ class DomainDispatcherApplication:
else:
host = "dynamodb"
elif service == "sagemaker":
host = f"api.{service}.{region}.amazonaws.com"
if environ["PATH_INFO"].endswith("invocations"):
host = f"runtime.{service}.{region}.amazonaws.com"
else:
host = f"api.{service}.{region}.amazonaws.com"
elif service == "timestream":
host = f"ingest.{service}.{region}.amazonaws.com"
elif service == "s3" and (

View File

@ -0,0 +1,5 @@
"""sagemakerruntime module initialization; sets value for base decorator."""
from .models import sagemakerruntime_backends
from ..core.models import base_decorator
# Public mock entry point for this service, built from the service's backends.
mock_sagemakerruntime = base_decorator(sagemakerruntime_backends)

View File

@ -0,0 +1,65 @@
from moto.core import BaseBackend, BackendDict
from typing import Dict, List, Tuple
class SageMakerRuntimeBackend(BaseBackend):
    """Implementation of SageMakerRuntime APIs."""

    def __init__(self, region_name: str, account_id: str):
        super().__init__(region_name, account_id)
        # Results that were already handed out, keyed by endpoint name and then
        # by the request representation, so a repeated request gets the same answer.
        self.results: Dict[str, Dict[bytes, Tuple[str, str, str, str]]] = {}
        # Pre-configured results that have not been handed out yet (first in, first out).
        self.results_queue: List[Tuple[str, str, str, str]] = []

    def invoke_endpoint(
        self, endpoint_name: str, unique_repr: bytes
    ) -> Tuple[str, str, str, str]:
        """
        This call will return static data by default.

        You can use a dedicated API to override this, by configuring a queue of expected results.

        A request to `invoke_endpoint` will take the first result from that queue. Subsequent requests using the same details will return the same result. Other requests using different details (for example a different endpoint name or body) will take the next result from the queue, or return static data if the queue is empty.

        Configure this queue by making an HTTP request to `/moto-api/static/sagemaker/endpoint-results`. An example invocation looks like this:

        .. sourcecode:: python

            expected_results = {
                "account_id": "123456789012",  # This is the default - can be omitted
                "region": "us-east-1",  # This is the default - can be omitted
                "results": [
                    {
                         "Body": "first body",
                         "ContentType": "text/xml",
                         "InvokedProductionVariant": "prod",
                         "CustomAttributes": "my_attr",
                    },
                    # other results as required
                ],
            }
            requests.post(
                "http://motoapi.amazonaws.com:5000/moto-api/static/sagemaker/endpoint-results",
                json=expected_results,
            )

            client = boto3.client("sagemaker-runtime", region_name="us-east-1")
            details = client.invoke_endpoint(EndpointName="asdf", Body="qwer")

        """
        endpoint_results = self.results.setdefault(endpoint_name, {})
        if unique_repr not in endpoint_results:
            if self.results_queue:
                # First time this request is seen: consume the oldest queued result.
                endpoint_results[unique_repr] = self.results_queue.pop(0)
            else:
                # Nothing configured: fall back to the static placeholder data.
                endpoint_results[unique_repr] = (
                    "body",
                    "content_type",
                    "invoked_production_variant",
                    "custom_attributes",
                )
        return endpoint_results[unique_repr]
# Backend instances for this service; looked up as
# sagemakerruntime_backends[account_id][region].
sagemakerruntime_backends = BackendDict(SageMakerRuntimeBackend, "sagemaker-runtime")

View File

@ -0,0 +1,44 @@
import base64
import json
from moto.core.common_types import TYPE_RESPONSE
from moto.core.responses import BaseResponse
from .models import sagemakerruntime_backends, SageMakerRuntimeBackend
class SageMakerRuntimeResponse(BaseResponse):
    """Handler for SageMakerRuntime requests and responses."""

    def __init__(self) -> None:
        super().__init__(service_name="sagemaker-runtime")

    @property
    def sagemakerruntime_backend(self) -> SageMakerRuntimeBackend:
        """Return backend instance specific for this region."""
        return sagemakerruntime_backends[self.current_account][self.region]

    def invoke_endpoint(self) -> TYPE_RESPONSE:
        """Handle InvokeEndpoint and return the backend's result for this request.

        The request is identified by its ``x-amzn-sagemaker*`` headers, its
        ``Accept`` header and its body; that representation is serialised and
        passed to the backend together with the endpoint name.

        Returns a 200 response carrying the result body. The ``Content-Type``,
        ``x-Amzn-Invoked-Production-Variant`` and
        ``X-Amzn-SageMaker-Custom-Attributes`` headers are only included when
        the backend supplied a value for them.
        """
        params = self._get_params()

        # Everything that distinguishes one invocation from another.
        unique_repr = {
            key: value
            for key, value in self.headers.items()
            if key.lower().startswith("x-amzn-sagemaker")
        }
        unique_repr["Accept"] = self.headers.get("Accept")
        unique_repr["Body"] = self.body
        endpoint_name = params.get("EndpointName")

        (
            body,
            content_type,
            invoked_production_variant,
            custom_attributes,
        ) = self.sagemakerruntime_backend.invoke_endpoint(
            endpoint_name=endpoint_name,  # type: ignore[arg-type]
            unique_repr=base64.b64encode(json.dumps(unique_repr).encode("utf-8")),
        )

        # Queued results may omit any of these fields; never emit a header
        # without a value (previously Content-Type could be sent as None).
        optional_headers = {
            "Content-Type": content_type,
            "x-Amzn-Invoked-Production-Variant": invoked_production_variant,
            "X-Amzn-SageMaker-Custom-Attributes": custom_attributes,
        }
        headers = {name: value for name, value in optional_headers.items() if value}
        return 200, headers, body

View File

@ -0,0 +1,14 @@
"""sagemakerruntime base URL and path."""
from .responses import SageMakerRuntimeResponse
# Host pattern served by this backend.
url_bases = [
    r"https?://runtime\.sagemaker\.(.+)\.amazonaws\.com",
]

# One response handler instance, shared by every route below.
response = SageMakerRuntimeResponse()

# Only the `invocations` path of an endpoint is routed; the endpoint name is
# captured as the `name` group.
url_paths = {
    "{0}/endpoints/(?P<name>[^/]+)/invocations$": response.dispatch,
}

View File

View File

@ -0,0 +1,56 @@
import boto3
import requests
from moto import mock_sagemakerruntime, settings
# See our Development Tips on writing tests for hints on how to write good tests:
# http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html
@mock_sagemakerruntime
def test_invoke_endpoint__default_results():
    """Without configured results, the static default data is returned."""
    runtime = boto3.client("sagemaker-runtime", region_name="ap-southeast-1")

    response = runtime.invoke_endpoint(
        EndpointName="asdf", Body="qwer", Accept="sth", TargetModel="tm"
    )

    payload = response["Body"].read()
    assert payload == b"body"
    assert response["CustomAttributes"] == "custom_attributes"
@mock_sagemakerruntime
def test_invoke_endpoint():
    """Configured results are served in order and are stable per request."""
    client = boto3.client("sagemaker-runtime", region_name="us-east-1")
    base_url = (
        "localhost:5000" if settings.TEST_SERVER_MODE else "motoapi.amazonaws.com"
    )

    sagemaker_result = {
        "results": [
            {
                "Body": "first body",
                "ContentType": "text/xml",
                "InvokedProductionVariant": "prod",
                "CustomAttributes": "my_attr",
            },
            {"Body": "second body"},
        ]
    }
    resp = requests.post(
        f"http://{base_url}/moto-api/static/sagemaker/endpoint-results",
        json=sagemaker_result,
    )
    # Fail right here if the queue could not be configured, instead of with a
    # confusing body mismatch further down.
    assert resp.status_code == 201

    # Return the first item from the list
    body = client.invoke_endpoint(EndpointName="asdf", Body="qwer")
    assert body["Body"].read() == b"first body"

    # Same input -> same output
    body = client.invoke_endpoint(EndpointName="asdf", Body="qwer")
    assert body["Body"].read() == b"first body"

    # Different input -> second item
    body = client.invoke_endpoint(
        EndpointName="asdf", Body="qwer", Accept="sth", TargetModel="tm"
    )
    assert body["Body"].read() == b"second body"