Feature: Sagemaker Runtime (#6747)
This commit is contained in:
parent
7098388ee4
commit
5dd649378c
@ -6437,6 +6437,14 @@
|
||||
- [ ] update_workteam
|
||||
</details>
|
||||
|
||||
## sagemaker-runtime
|
||||
<details>
|
||||
<summary>50% implemented</summary>
|
||||
|
||||
- [X] invoke_endpoint
|
||||
- [ ] invoke_endpoint_async
|
||||
</details>
|
||||
|
||||
## scheduler
|
||||
<details>
|
||||
<summary>100% implemented</summary>
|
||||
@ -7482,7 +7490,6 @@
|
||||
- sagemaker-featurestore-runtime
|
||||
- sagemaker-geospatial
|
||||
- sagemaker-metrics
|
||||
- sagemaker-runtime
|
||||
- savingsplans
|
||||
- schemas
|
||||
- securityhub
|
||||
|
66
docs/docs/services/sagemaker-runtime.rst
Normal file
66
docs/docs/services/sagemaker-runtime.rst
Normal file
@ -0,0 +1,66 @@
|
||||
.. _implementedservice_sagemaker-runtime:
|
||||
|
||||
.. |start-h3| raw:: html
|
||||
|
||||
<h3>
|
||||
|
||||
.. |end-h3| raw:: html
|
||||
|
||||
</h3>
|
||||
|
||||
=================
|
||||
sagemaker-runtime
|
||||
=================
|
||||
|
||||
.. autoclass:: moto.sagemakerruntime.models.SageMakerRuntimeBackend
|
||||
|
||||
|start-h3| Example usage |end-h3|
|
||||
|
||||
.. sourcecode:: python
|
||||
|
||||
@mock_sagemakerruntime
|
||||
def test_sagemakerruntime_behaviour():
|
||||
boto3.client("sagemaker-runtime")
|
||||
...
|
||||
|
||||
|
||||
|
||||
|start-h3| Implemented features for this service |end-h3|
|
||||
|
||||
- [X] invoke_endpoint
|
||||
|
||||
This call will return static data by default.
|
||||
|
||||
You can use a dedicated API to override this, by configuring a queue of expected results.
|
||||
|
||||
A request to `invoke_endpoint` will take the first result from that queue. Subsequent requests using the same details will return the same result. Other requests using different details (a different endpoint name, body or headers) will take the next result from the queue, or return static data if the queue is empty.
|
||||
|
||||
Configure this queue by making an HTTP request to `/moto-api/static/sagemaker/endpoint-results`. An example invocation looks like this:
|
||||
|
||||
.. sourcecode:: python
|
||||
|
||||
expected_results = {
|
||||
"account_id": "123456789012", # This is the default - can be omitted
|
||||
"region": "us-east-1", # This is the default - can be omitted
|
||||
"results": [
|
||||
{
|
||||
"Body": "first body",
|
||||
"ContentType": "text/xml",
|
||||
"InvokedProductionVariant": "prod",
|
||||
"CustomAttributes": "my_attr",
|
||||
},
|
||||
# other results as required
|
||||
],
|
||||
}
|
||||
requests.post(
|
||||
"http://motoapi.amazonaws.com:5000/moto-api/static/sagemaker/endpoint-results",
|
||||
json=expected_results,
|
||||
)
|
||||
|
||||
client = boto3.client("sagemaker-runtime", region_name="us-east-1")
|
||||
details = client.invoke_endpoint(EndpointName="asdf", Body="qwer")
|
||||
|
||||
|
||||
|
||||
- [ ] invoke_endpoint_async
|
||||
|
@ -157,6 +157,9 @@ mock_route53resolver = lazy_load(
|
||||
mock_s3 = lazy_load(".s3", "mock_s3")
|
||||
mock_s3control = lazy_load(".s3control", "mock_s3control")
|
||||
mock_sagemaker = lazy_load(".sagemaker", "mock_sagemaker")
|
||||
mock_sagemakerruntime = lazy_load(
|
||||
".sagemakerruntime", "mock_sagemakerruntime", boto3_name="sagemaker-runtime"
|
||||
)
|
||||
mock_scheduler = lazy_load(".scheduler", "mock_scheduler")
|
||||
mock_sdb = lazy_load(".sdb", "mock_sdb")
|
||||
mock_secretsmanager = lazy_load(".secretsmanager", "mock_secretsmanager")
|
||||
|
@ -1,4 +1,4 @@
|
||||
# autogenerated by /Users/plussier/dev/github/moto/scripts/update_backend_index.py
|
||||
# autogenerated by /home/bblommers/Software/Code/bblommers/moto/scripts/update_backend_index.py
|
||||
import re
|
||||
|
||||
backend_url_patterns = [
|
||||
@ -153,6 +153,10 @@ backend_url_patterns = [
|
||||
re.compile("https?://([0-9]+)\\.s3-control\\.(.+)\\.amazonaws\\.com"),
|
||||
),
|
||||
("sagemaker", re.compile("https?://api\\.sagemaker\\.(.+)\\.amazonaws.com")),
|
||||
(
|
||||
"sagemaker-runtime",
|
||||
re.compile("https?://runtime\\.sagemaker\\.(.+)\\.amazonaws\\.com"),
|
||||
),
|
||||
("scheduler", re.compile("https?://scheduler\\.(.+)\\.amazonaws\\.com")),
|
||||
("sdb", re.compile("https?://sdb\\.(.+)\\.amazonaws\\.com")),
|
||||
("secretsmanager", re.compile("https?://secretsmanager\\.(.+)\\.amazonaws\\.com")),
|
||||
|
@ -46,6 +46,20 @@ class MotoAPIBackend(BaseBackend):
|
||||
results = QueryResults(rows=rows, column_info=column_info)
|
||||
backend.query_results_queue.append(results)
|
||||
|
||||
def set_sagemaker_result(
    self,
    body: str,
    content_type: str,
    prod_variant: str,
    custom_attrs: str,
    account_id: str,
    region: str,
) -> None:
    """Queue one expected ``invoke_endpoint`` result.

    The result is appended, as a ``(body, content_type, prod_variant,
    custom_attrs)`` tuple, to the ``results_queue`` of the SageMaker Runtime
    backend selected by ``account_id`` and ``region``.
    """
    # Imported lazily, inside the function, as in the original implementation.
    from moto.sagemakerruntime.models import sagemakerruntime_backends

    queued_result = (body, content_type, prod_variant, custom_attrs)
    sagemakerruntime_backends[account_id][region].results_queue.append(queued_result)
|
||||
|
||||
def set_rds_data_result(
|
||||
self,
|
||||
records: Optional[List[List[Dict[str, Any]]]],
|
||||
|
@ -168,6 +168,35 @@ class MotoAPIResponse(BaseResponse):
|
||||
)
|
||||
return 201, {}, ""
|
||||
|
||||
def set_sagemaker_result(
    self,
    request: Any,
    full_url: str,  # pylint: disable=unused-argument
    headers: Any,
) -> TYPE_RESPONSE:
    """Register the endpoint results posted to the moto API.

    The request body is read from the WSGI input stream (``Content-Length``
    bytes) and parsed as JSON. Every entry under ``"results"`` is forwarded
    to ``moto_api_backend.set_sagemaker_result`` for the requested account
    and region; both fall back to defaults when omitted.

    Returns a ``201`` response with no headers and an empty body.
    """
    from .models import moto_api_backend

    content_length = int(headers["Content-Length"])
    raw_payload = request.environ["wsgi.input"].read(content_length)
    payload = json.loads(raw_payload.decode("utf-8"))
    account_id = payload.get("account_id", DEFAULT_ACCOUNT_ID)
    region = payload.get("region", "us-east-1")

    for entry in payload.get("results", []):
        # "Body" is mandatory (KeyError when absent); the other fields are optional.
        moto_api_backend.set_sagemaker_result(
            body=entry["Body"],
            content_type=entry.get("ContentType"),
            prod_variant=entry.get("InvokedProductionVariant"),
            custom_attrs=entry.get("CustomAttributes"),
            account_id=account_id,
            region=region,
        )
    return 201, {}, ""
|
||||
|
||||
def set_rds_data_result(
|
||||
self,
|
||||
request: Any,
|
||||
|
@ -13,6 +13,7 @@ url_paths = {
|
||||
"{0}/moto-api/reset-auth": response_instance.reset_auth_response,
|
||||
"{0}/moto-api/seed": response_instance.seed,
|
||||
"{0}/moto-api/static/athena/query-results": response_instance.set_athena_result,
|
||||
"{0}/moto-api/static/sagemaker/endpoint-results": response_instance.set_sagemaker_result,
|
||||
"{0}/moto-api/static/rds-data/statement-results": response_instance.set_rds_data_result,
|
||||
"{0}/moto-api/state-manager/get-transition": response_instance.get_transition,
|
||||
"{0}/moto-api/state-manager/set-transition": response_instance.set_transition,
|
||||
|
@ -153,7 +153,10 @@ class DomainDispatcherApplication:
|
||||
else:
|
||||
host = "dynamodb"
|
||||
elif service == "sagemaker":
|
||||
host = f"api.{service}.{region}.amazonaws.com"
|
||||
if environ["PATH_INFO"].endswith("invocations"):
|
||||
host = f"runtime.{service}.{region}.amazonaws.com"
|
||||
else:
|
||||
host = f"api.{service}.{region}.amazonaws.com"
|
||||
elif service == "timestream":
|
||||
host = f"ingest.{service}.{region}.amazonaws.com"
|
||||
elif service == "s3" and (
|
||||
|
5
moto/sagemakerruntime/__init__.py
Normal file
5
moto/sagemakerruntime/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
"""sagemakerruntime module initialization; sets value for base decorator."""
|
||||
from .models import sagemakerruntime_backends
|
||||
from ..core.models import base_decorator
|
||||
|
||||
# Public mock for this service, built by base_decorator from the sagemaker-runtime backends.
mock_sagemakerruntime = base_decorator(sagemakerruntime_backends)
|
65
moto/sagemakerruntime/models.py
Normal file
65
moto/sagemakerruntime/models.py
Normal file
@ -0,0 +1,65 @@
|
||||
from moto.core import BaseBackend, BackendDict
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
|
||||
class SageMakerRuntimeBackend(BaseBackend):
    """Implementation of SageMakerRuntime APIs."""

    def __init__(self, region_name: str, account_id: str):
        super().__init__(region_name, account_id)
        # Results already handed out: endpoint name -> encoded request -> result,
        # so that a repeated request receives the same answer.
        self.results: Dict[str, Dict[bytes, Tuple[str, str, str, str]]] = {}
        # Configured results that have not been handed out yet, consumed front-first.
        self.results_queue: List[Tuple[str, str, str, str]] = []

    def invoke_endpoint(
        self, endpoint_name: str, unique_repr: bytes
    ) -> Tuple[str, str, str, str]:
        """
        This call will return static data by default.

        You can use a dedicated API to override this, by configuring a queue of expected results.

        A request to `invoke_endpoint` will take the first result from that queue. Subsequent requests using the same details will return the same result. Other requests using different details (a different endpoint name, body or headers) will take the next result from the queue, or return static data if the queue is empty.

        Configure this queue by making an HTTP request to `/moto-api/static/sagemaker/endpoint-results`. An example invocation looks like this:

        .. sourcecode:: python

            expected_results = {
                "account_id": "123456789012",  # This is the default - can be omitted
                "region": "us-east-1",  # This is the default - can be omitted
                "results": [
                    {
                         "Body": "first body",
                         "ContentType": "text/xml",
                         "InvokedProductionVariant": "prod",
                         "CustomAttributes": "my_attr",
                    },
                    # other results as required
                ],
            }
            requests.post(
                "http://motoapi.amazonaws.com:5000/moto-api/static/sagemaker/endpoint-results",
                json=expected_results,
            )

            client = boto3.client("sagemaker-runtime", region_name="us-east-1")
            details = client.invoke_endpoint(EndpointName="asdf", Body="qwer")

        """
        endpoint_results = self.results.setdefault(endpoint_name, {})
        if unique_repr not in endpoint_results:
            if self.results_queue:
                # First time this request is seen: consume the oldest configured result.
                endpoint_results[unique_repr] = self.results_queue.pop(0)
            else:
                # Nothing configured (or queue exhausted): fall back to static data.
                endpoint_results[unique_repr] = (
                    "body",
                    "content_type",
                    "invoked_production_variant",
                    "custom_attributes",
                )
        return endpoint_results[unique_repr]
|
||||
|
||||
|
||||
# Backend registry for the "sagemaker-runtime" service; callers index it as [account_id][region].
sagemakerruntime_backends = BackendDict(SageMakerRuntimeBackend, "sagemaker-runtime")
|
44
moto/sagemakerruntime/responses.py
Normal file
44
moto/sagemakerruntime/responses.py
Normal file
@ -0,0 +1,44 @@
|
||||
import base64
|
||||
import json
|
||||
|
||||
from moto.core.common_types import TYPE_RESPONSE
|
||||
from moto.core.responses import BaseResponse
|
||||
from .models import sagemakerruntime_backends, SageMakerRuntimeBackend
|
||||
|
||||
|
||||
class SageMakerRuntimeResponse(BaseResponse):
    """Handler for SageMakerRuntime requests and responses."""

    def __init__(self) -> None:
        super().__init__(service_name="sagemaker-runtime")

    @property
    def sagemakerruntime_backend(self) -> SageMakerRuntimeBackend:
        """Return the backend for the current account and region."""
        account_backends = sagemakerruntime_backends[self.current_account]
        return account_backends[self.region]

    def invoke_endpoint(self) -> TYPE_RESPONSE:
        """Answer an InvokeEndpoint call with the backend's result.

        The request is summarised by its ``x-amzn-sagemaker*`` headers, its
        ``Accept`` header and its body. That summary is JSON-serialised and
        base64-encoded, and the backend uses it to recognise repeated requests.
        """
        params = self._get_params()

        # Insertion order matters: it determines the serialised representation.
        request_details = {}
        for header_name, header_value in self.headers.items():
            if header_name.lower().startswith("x-amzn-sagemaker"):
                request_details[header_name] = header_value
        request_details["Accept"] = self.headers.get("Accept")
        request_details["Body"] = self.body

        endpoint_name = params.get("EndpointName")
        encoded_details = base64.b64encode(
            json.dumps(request_details).encode("utf-8")
        )
        result = self.sagemakerruntime_backend.invoke_endpoint(
            endpoint_name=endpoint_name,  # type: ignore[arg-type]
            unique_repr=encoded_details,
        )
        body, content_type, invoked_production_variant, custom_attributes = result

        response_headers = {"Content-Type": content_type}
        optional_headers = (
            ("x-Amzn-Invoked-Production-Variant", invoked_production_variant),
            ("X-Amzn-SageMaker-Custom-Attributes", custom_attributes),
        )
        for header_name, header_value in optional_headers:
            # Optional headers are only sent when the result supplies a value.
            if header_value:
                response_headers[header_name] = header_value
        return 200, response_headers, body
|
14
moto/sagemakerruntime/urls.py
Normal file
14
moto/sagemakerruntime/urls.py
Normal file
@ -0,0 +1,14 @@
|
||||
"""sagemakerruntime base URL and path."""
|
||||
from .responses import SageMakerRuntimeResponse
|
||||
|
||||
url_bases = [
|
||||
r"https?://runtime\.sagemaker\.(.+)\.amazonaws\.com",
|
||||
]
|
||||
|
||||
|
||||
response = SageMakerRuntimeResponse()
|
||||
|
||||
|
||||
url_paths = {
|
||||
"{0}/endpoints/(?P<name>[^/]+)/invocations$": response.dispatch,
|
||||
}
|
0
tests/test_sagemakerruntime/__init__.py
Normal file
0
tests/test_sagemakerruntime/__init__.py
Normal file
56
tests/test_sagemakerruntime/test_sagemakerruntime.py
Normal file
56
tests/test_sagemakerruntime/test_sagemakerruntime.py
Normal file
@ -0,0 +1,56 @@
|
||||
import boto3
|
||||
import requests
|
||||
|
||||
from moto import mock_sagemakerruntime, settings
|
||||
|
||||
# See our Development Tips on writing tests for hints on how to write good tests:
|
||||
# http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html
|
||||
|
||||
|
||||
@mock_sagemakerruntime
def test_invoke_endpoint__default_results():
    """Without configured results, the static default data is returned."""
    runtime = boto3.client("sagemaker-runtime", region_name="ap-southeast-1")
    response = runtime.invoke_endpoint(
        EndpointName="asdf",
        Body="qwer",
        Accept="sth",
        TargetModel="tm",
    )

    assert response["Body"].read() == b"body"
    assert response["CustomAttributes"] == "custom_attributes"
|
||||
|
||||
|
||||
@mock_sagemakerruntime
def test_invoke_endpoint():
    """Configured results are served in order and remembered per distinct request."""
    client = boto3.client("sagemaker-runtime", region_name="us-east-1")
    if settings.TEST_SERVER_MODE:
        moto_api_host = "localhost:5000"
    else:
        moto_api_host = "motoapi.amazonaws.com"

    first_result = {
        "Body": "first body",
        "ContentType": "text/xml",
        "InvokedProductionVariant": "prod",
        "CustomAttributes": "my_attr",
    }
    second_result = {"Body": "second body"}
    requests.post(
        f"http://{moto_api_host}/moto-api/static/sagemaker/endpoint-results",
        json={"results": [first_result, second_result]},
    )

    # The first request consumes the first configured result
    response = client.invoke_endpoint(EndpointName="asdf", Body="qwer")
    assert response["Body"].read() == b"first body"

    # Repeating the identical request yields the identical result
    response = client.invoke_endpoint(EndpointName="asdf", Body="qwer")
    assert response["Body"].read() == b"first body"

    # A request with different details moves on to the second result
    response = client.invoke_endpoint(
        EndpointName="asdf",
        Body="qwer",
        Accept="sth",
        TargetModel="tm",
    )
    assert response["Body"].read() == b"second body"
|
Loading…
x
Reference in New Issue
Block a user