Feature: Sagemaker Runtime (#6747)
This commit is contained in:
parent
7098388ee4
commit
5dd649378c
@ -6437,6 +6437,14 @@
|
|||||||
- [ ] update_workteam
|
- [ ] update_workteam
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
## sagemaker-runtime
|
||||||
|
<details>
|
||||||
|
<summary>50% implemented</summary>
|
||||||
|
|
||||||
|
- [X] invoke_endpoint
|
||||||
|
- [ ] invoke_endpoint_async
|
||||||
|
</details>
|
||||||
|
|
||||||
## scheduler
|
## scheduler
|
||||||
<details>
|
<details>
|
||||||
<summary>100% implemented</summary>
|
<summary>100% implemented</summary>
|
||||||
@ -7482,7 +7490,6 @@
|
|||||||
- sagemaker-featurestore-runtime
|
- sagemaker-featurestore-runtime
|
||||||
- sagemaker-geospatial
|
- sagemaker-geospatial
|
||||||
- sagemaker-metrics
|
- sagemaker-metrics
|
||||||
- sagemaker-runtime
|
|
||||||
- savingsplans
|
- savingsplans
|
||||||
- schemas
|
- schemas
|
||||||
- securityhub
|
- securityhub
|
||||||
|
66
docs/docs/services/sagemaker-runtime.rst
Normal file
66
docs/docs/services/sagemaker-runtime.rst
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
.. _implementedservice_sagemaker-runtime:
|
||||||
|
|
||||||
|
.. |start-h3| raw:: html
|
||||||
|
|
||||||
|
<h3>
|
||||||
|
|
||||||
|
.. |end-h3| raw:: html
|
||||||
|
|
||||||
|
</h3>
|
||||||
|
|
||||||
|
=================
|
||||||
|
sagemaker-runtime
|
||||||
|
=================
|
||||||
|
|
||||||
|
.. autoclass:: moto.sagemakerruntime.models.SageMakerRuntimeBackend
|
||||||
|
|
||||||
|
|start-h3| Example usage |end-h3|
|
||||||
|
|
||||||
|
.. sourcecode:: python
|
||||||
|
|
||||||
|
@mock_sagemakerruntime
|
||||||
|
def test_sagemakerruntime_behaviour():
|
||||||
|
boto3.client("sagemaker-runtime")
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|start-h3| Implemented features for this service |end-h3|
|
||||||
|
|
||||||
|
- [X] invoke_endpoint
|
||||||
|
|
||||||
|
This call will return static data by default.
|
||||||
|
|
||||||
|
You can use a dedicated API to override this, by configuring a queue of expected results.
|
||||||
|
|
||||||
|
A request to `invoke_endpoint` will take the first result from that queue. Subsequent requests using the same details will return the same result. Other requests, using a different endpoint name or different request details, will take the next result from the queue, or return static data if the queue is empty.
|
||||||
|
|
||||||
|
Configure this queue by making an HTTP request to `/moto-api/static/sagemaker/endpoint-results`. An example invocation looks like this:
|
||||||
|
|
||||||
|
.. sourcecode:: python
|
||||||
|
|
||||||
|
expected_results = {
|
||||||
|
"account_id": "123456789012", # This is the default - can be omitted
|
||||||
|
"region": "us-east-1", # This is the default - can be omitted
|
||||||
|
"results": [
|
||||||
|
{
|
||||||
|
"Body": "first body",
|
||||||
|
"ContentType": "text/xml",
|
||||||
|
"InvokedProductionVariant": "prod",
|
||||||
|
"CustomAttributes": "my_attr",
|
||||||
|
},
|
||||||
|
# other results as required
|
||||||
|
],
|
||||||
|
}
|
||||||
|
requests.post(
|
||||||
|
"http://motoapi.amazonaws.com:5000/moto-api/static/sagemaker/endpoint-results",
|
||||||
|
json=expected_results,
|
||||||
|
)
|
||||||
|
|
||||||
|
client = boto3.client("sagemaker-runtime", region_name="us-east-1")
|
||||||
|
details = client.invoke_endpoint(EndpointName="asdf", Body="qwer")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
- [ ] invoke_endpoint_async
|
||||||
|
|
@ -157,6 +157,9 @@ mock_route53resolver = lazy_load(
|
|||||||
mock_s3 = lazy_load(".s3", "mock_s3")
|
mock_s3 = lazy_load(".s3", "mock_s3")
|
||||||
mock_s3control = lazy_load(".s3control", "mock_s3control")
|
mock_s3control = lazy_load(".s3control", "mock_s3control")
|
||||||
mock_sagemaker = lazy_load(".sagemaker", "mock_sagemaker")
|
mock_sagemaker = lazy_load(".sagemaker", "mock_sagemaker")
|
||||||
|
mock_sagemakerruntime = lazy_load(
|
||||||
|
".sagemakerruntime", "mock_sagemakerruntime", boto3_name="sagemaker-runtime"
|
||||||
|
)
|
||||||
mock_scheduler = lazy_load(".scheduler", "mock_scheduler")
|
mock_scheduler = lazy_load(".scheduler", "mock_scheduler")
|
||||||
mock_sdb = lazy_load(".sdb", "mock_sdb")
|
mock_sdb = lazy_load(".sdb", "mock_sdb")
|
||||||
mock_secretsmanager = lazy_load(".secretsmanager", "mock_secretsmanager")
|
mock_secretsmanager = lazy_load(".secretsmanager", "mock_secretsmanager")
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
# autogenerated by /Users/plussier/dev/github/moto/scripts/update_backend_index.py
|
# autogenerated by /home/bblommers/Software/Code/bblommers/moto/scripts/update_backend_index.py
|
||||||
import re
|
import re
|
||||||
|
|
||||||
backend_url_patterns = [
|
backend_url_patterns = [
|
||||||
@ -153,6 +153,10 @@ backend_url_patterns = [
|
|||||||
re.compile("https?://([0-9]+)\\.s3-control\\.(.+)\\.amazonaws\\.com"),
|
re.compile("https?://([0-9]+)\\.s3-control\\.(.+)\\.amazonaws\\.com"),
|
||||||
),
|
),
|
||||||
("sagemaker", re.compile("https?://api\\.sagemaker\\.(.+)\\.amazonaws.com")),
|
("sagemaker", re.compile("https?://api\\.sagemaker\\.(.+)\\.amazonaws.com")),
|
||||||
|
(
|
||||||
|
"sagemaker-runtime",
|
||||||
|
re.compile("https?://runtime\\.sagemaker\\.(.+)\\.amazonaws\\.com"),
|
||||||
|
),
|
||||||
("scheduler", re.compile("https?://scheduler\\.(.+)\\.amazonaws\\.com")),
|
("scheduler", re.compile("https?://scheduler\\.(.+)\\.amazonaws\\.com")),
|
||||||
("sdb", re.compile("https?://sdb\\.(.+)\\.amazonaws\\.com")),
|
("sdb", re.compile("https?://sdb\\.(.+)\\.amazonaws\\.com")),
|
||||||
("secretsmanager", re.compile("https?://secretsmanager\\.(.+)\\.amazonaws\\.com")),
|
("secretsmanager", re.compile("https?://secretsmanager\\.(.+)\\.amazonaws\\.com")),
|
||||||
|
@ -46,6 +46,20 @@ class MotoAPIBackend(BaseBackend):
|
|||||||
results = QueryResults(rows=rows, column_info=column_info)
|
results = QueryResults(rows=rows, column_info=column_info)
|
||||||
backend.query_results_queue.append(results)
|
backend.query_results_queue.append(results)
|
||||||
|
|
||||||
|
def set_sagemaker_result(
    self,
    body: str,
    content_type: str,
    prod_variant: str,
    custom_attrs: str,
    account_id: str,
    region: str,
) -> None:
    """Queue one canned ``invoke_endpoint`` result.

    The four result fields are stored as a tuple, in the order
    ``(body, content_type, prod_variant, custom_attrs)``, at the end of the
    ``results_queue`` of the SageMaker Runtime backend selected by
    ``account_id`` and ``region``.
    """
    # Function-scope import, kept where the original code had it.
    from moto.sagemakerruntime.models import sagemakerruntime_backends

    queued_result = (body, content_type, prod_variant, custom_attrs)
    target_backend = sagemakerruntime_backends[account_id][region]
    target_backend.results_queue.append(queued_result)
|
||||||
|
|
||||||
def set_rds_data_result(
|
def set_rds_data_result(
|
||||||
self,
|
self,
|
||||||
records: Optional[List[List[Dict[str, Any]]]],
|
records: Optional[List[List[Dict[str, Any]]]],
|
||||||
|
@ -168,6 +168,35 @@ class MotoAPIResponse(BaseResponse):
|
|||||||
)
|
)
|
||||||
return 201, {}, ""
|
return 201, {}, ""
|
||||||
|
|
||||||
|
def set_sagemaker_result(
    self,
    request: Any,
    full_url: str,  # pylint: disable=unused-argument
    headers: Any,
) -> TYPE_RESPONSE:
    """Handle a request that queues canned SageMaker Runtime results.

    The JSON payload is read from the WSGI input stream, using the
    ``Content-Length`` header to know how many bytes to read. Its optional
    ``account_id`` and ``region`` keys select the target backend (falling
    back to ``DEFAULT_ACCOUNT_ID`` and ``us-east-1``). Every entry of its
    ``results`` list is forwarded, in order, to
    ``moto_api_backend.set_sagemaker_result``.

    Returns a 201 status with no headers and an empty body.
    """
    from .models import moto_api_backend

    content_length = int(headers["Content-Length"])
    raw_payload = request.environ["wsgi.input"].read(content_length)
    payload = json.loads(raw_payload.decode("utf-8"))
    account_id = payload.get("account_id", DEFAULT_ACCOUNT_ID)
    region = payload.get("region", "us-east-1")

    for entry in payload.get("results", []):
        moto_api_backend.set_sagemaker_result(
            # "Body" is the only mandatory key; the others default to None.
            body=entry["Body"],
            content_type=entry.get("ContentType"),
            prod_variant=entry.get("InvokedProductionVariant"),
            custom_attrs=entry.get("CustomAttributes"),
            account_id=account_id,
            region=region,
        )
    return 201, {}, ""
|
||||||
|
|
||||||
def set_rds_data_result(
|
def set_rds_data_result(
|
||||||
self,
|
self,
|
||||||
request: Any,
|
request: Any,
|
||||||
|
@ -13,6 +13,7 @@ url_paths = {
|
|||||||
"{0}/moto-api/reset-auth": response_instance.reset_auth_response,
|
"{0}/moto-api/reset-auth": response_instance.reset_auth_response,
|
||||||
"{0}/moto-api/seed": response_instance.seed,
|
"{0}/moto-api/seed": response_instance.seed,
|
||||||
"{0}/moto-api/static/athena/query-results": response_instance.set_athena_result,
|
"{0}/moto-api/static/athena/query-results": response_instance.set_athena_result,
|
||||||
|
"{0}/moto-api/static/sagemaker/endpoint-results": response_instance.set_sagemaker_result,
|
||||||
"{0}/moto-api/static/rds-data/statement-results": response_instance.set_rds_data_result,
|
"{0}/moto-api/static/rds-data/statement-results": response_instance.set_rds_data_result,
|
||||||
"{0}/moto-api/state-manager/get-transition": response_instance.get_transition,
|
"{0}/moto-api/state-manager/get-transition": response_instance.get_transition,
|
||||||
"{0}/moto-api/state-manager/set-transition": response_instance.set_transition,
|
"{0}/moto-api/state-manager/set-transition": response_instance.set_transition,
|
||||||
|
@ -153,7 +153,10 @@ class DomainDispatcherApplication:
|
|||||||
else:
|
else:
|
||||||
host = "dynamodb"
|
host = "dynamodb"
|
||||||
elif service == "sagemaker":
|
elif service == "sagemaker":
|
||||||
host = f"api.{service}.{region}.amazonaws.com"
|
if environ["PATH_INFO"].endswith("invocations"):
|
||||||
|
host = f"runtime.{service}.{region}.amazonaws.com"
|
||||||
|
else:
|
||||||
|
host = f"api.{service}.{region}.amazonaws.com"
|
||||||
elif service == "timestream":
|
elif service == "timestream":
|
||||||
host = f"ingest.{service}.{region}.amazonaws.com"
|
host = f"ingest.{service}.{region}.amazonaws.com"
|
||||||
elif service == "s3" and (
|
elif service == "s3" and (
|
||||||
|
5
moto/sagemakerruntime/__init__.py
Normal file
5
moto/sagemakerruntime/__init__.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
"""sagemakerruntime module initialization; sets value for base decorator."""
|
||||||
|
from .models import sagemakerruntime_backends
|
||||||
|
from ..core.models import base_decorator
|
||||||
|
|
||||||
|
# Mock decorator for this service, built by base_decorator from the sagemakerruntime backends.
mock_sagemakerruntime = base_decorator(sagemakerruntime_backends)
|
65
moto/sagemakerruntime/models.py
Normal file
65
moto/sagemakerruntime/models.py
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
from moto.core import BaseBackend, BackendDict
|
||||||
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
class SageMakerRuntimeBackend(BaseBackend):
    """Implementation of SageMakerRuntime APIs."""

    def __init__(self, region_name: str, account_id: str):
        super().__init__(region_name, account_id)
        # Results already handed out, keyed by endpoint name and then by the
        # encoded request representation, so a repeated request gets the same
        # answer again.
        self.results: Dict[str, Dict[bytes, Tuple[str, str, str, str]]] = {}
        # Configured results that have not been handed out yet; consumed from
        # the front.
        self.results_queue: List[Tuple[str, str, str, str]] = []

    def invoke_endpoint(
        self, endpoint_name: str, unique_repr: bytes
    ) -> Tuple[str, str, str, str]:
        """
        This call will return static data by default.

        You can use a dedicated API to override this, by configuring a queue of expected results.

        A request to `invoke_endpoint` will take the first result from that queue. Subsequent requests using the same details will return the same result. Other requests, using a different endpoint name or different request details, will take the next result from the queue, or return static data if the queue is empty.

        Configure this queue by making an HTTP request to `/moto-api/static/sagemaker/endpoint-results`. An example invocation looks like this:

        .. sourcecode:: python

            expected_results = {
                "account_id": "123456789012",  # This is the default - can be omitted
                "region": "us-east-1",  # This is the default - can be omitted
                "results": [
                    {
                         "Body": "first body",
                         "ContentType": "text/xml",
                         "InvokedProductionVariant": "prod",
                         "CustomAttributes": "my_attr",
                    },
                    # other results as required
                ],
            }
            requests.post(
                "http://motoapi.amazonaws.com:5000/moto-api/static/sagemaker/endpoint-results",
                json=expected_results,
            )

            client = boto3.client("sagemaker-runtime", region_name="us-east-1")
            details = client.invoke_endpoint(EndpointName="asdf", Body="qwer")

        """
        endpoint_results = self.results.setdefault(endpoint_name, {})
        if unique_repr not in endpoint_results:
            if self.results_queue:
                # Hand out the oldest configured result and remember it for
                # this exact request.
                endpoint_results[unique_repr] = self.results_queue.pop(0)
            else:
                # Nothing configured: fall back to the static placeholders.
                endpoint_results[unique_repr] = (
                    "body",
                    "content_type",
                    "invoked_production_variant",
                    "custom_attributes",
                )
        return endpoint_results[unique_repr]
|
||||||
|
|
||||||
|
|
||||||
|
# BackendDict of SageMakerRuntimeBackend instances for the "sagemaker-runtime" service.
sagemakerruntime_backends = BackendDict(SageMakerRuntimeBackend, "sagemaker-runtime")
|
44
moto/sagemakerruntime/responses.py
Normal file
44
moto/sagemakerruntime/responses.py
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
import base64
|
||||||
|
import json
|
||||||
|
|
||||||
|
from moto.core.common_types import TYPE_RESPONSE
|
||||||
|
from moto.core.responses import BaseResponse
|
||||||
|
from .models import sagemakerruntime_backends, SageMakerRuntimeBackend
|
||||||
|
|
||||||
|
|
||||||
|
class SageMakerRuntimeResponse(BaseResponse):
    """Handler for SageMakerRuntime requests and responses."""

    def __init__(self) -> None:
        super().__init__(service_name="sagemaker-runtime")

    @property
    def sagemakerruntime_backend(self) -> SageMakerRuntimeBackend:
        """Return backend instance specific for this region."""
        account_backends = sagemakerruntime_backends[self.current_account]
        return account_backends[self.region]

    def invoke_endpoint(self) -> TYPE_RESPONSE:
        """Serve an InvokeEndpoint call from the backend's stored results.

        The request is summarised as its ``x-amzn-sagemaker*`` headers, its
        ``Accept`` header and its body. That summary is JSON-serialised and
        base64-encoded, and handed to the backend together with the
        ``EndpointName`` parameter. The backend's answer is returned with
        status 200; the production-variant and custom-attributes headers are
        only added when their values are truthy.
        """
        params = self._get_params()
        request_details = {
            header_name: header_value
            for header_name, header_value in self.headers.items()
            if header_name.lower().startswith("x-amzn-sagemaker")
        }
        request_details.update(
            {"Accept": self.headers.get("Accept"), "Body": self.body}
        )
        endpoint_name = params.get("EndpointName")
        encoded_details = base64.b64encode(
            json.dumps(request_details).encode("utf-8")
        )
        result = self.sagemakerruntime_backend.invoke_endpoint(
            endpoint_name=endpoint_name,  # type: ignore[arg-type]
            unique_repr=encoded_details,
        )
        body, content_type, invoked_production_variant, custom_attributes = result

        headers = {"Content-Type": content_type}
        if invoked_production_variant:
            headers["x-Amzn-Invoked-Production-Variant"] = invoked_production_variant
        if custom_attributes:
            headers["X-Amzn-SageMaker-Custom-Attributes"] = custom_attributes
        return 200, headers, body
|
14
moto/sagemakerruntime/urls.py
Normal file
14
moto/sagemakerruntime/urls.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
"""sagemakerruntime base URL and path."""
|
||||||
|
from .responses import SageMakerRuntimeResponse
|
||||||
|
|
||||||
|
# Host pattern handled by this service.
url_bases = [r"https?://runtime\.sagemaker\.(.+)\.amazonaws\.com"]

# A single response instance dispatches the route below.
response = SageMakerRuntimeResponse()

# Maps the invocations path (with a named "name" group) to the dispatcher.
url_paths = {"{0}/endpoints/(?P<name>[^/]+)/invocations$": response.dispatch}
|
0
tests/test_sagemakerruntime/__init__.py
Normal file
0
tests/test_sagemakerruntime/__init__.py
Normal file
56
tests/test_sagemakerruntime/test_sagemakerruntime.py
Normal file
56
tests/test_sagemakerruntime/test_sagemakerruntime.py
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
import boto3
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from moto import mock_sagemakerruntime, settings
|
||||||
|
|
||||||
|
# See our Development Tips on writing tests for hints on how to write good tests:
|
||||||
|
# http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sagemakerruntime
def test_invoke_endpoint__default_results():
    """Without configured results, the static default response is returned."""
    runtime = boto3.client("sagemaker-runtime", region_name="ap-southeast-1")
    response = runtime.invoke_endpoint(
        EndpointName="asdf",
        Body="qwer",
        Accept="sth",
        TargetModel="tm",
    )

    assert response["Body"].read() == b"body"
    assert response["CustomAttributes"] == "custom_attributes"
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sagemakerruntime
def test_invoke_endpoint():
    """Queued results are served in order and stay stable per distinct request."""
    client = boto3.client("sagemaker-runtime", region_name="us-east-1")
    if settings.TEST_SERVER_MODE:
        base_url = "localhost:5000"
    else:
        base_url = "motoapi.amazonaws.com"

    first_result = {
        "Body": "first body",
        "ContentType": "text/xml",
        "InvokedProductionVariant": "prod",
        "CustomAttributes": "my_attr",
    }
    second_result = {"Body": "second body"}
    requests.post(
        f"http://{base_url}/moto-api/static/sagemaker/endpoint-results",
        json={"results": [first_result, second_result]},
    )

    # Return the first item from the list
    response = client.invoke_endpoint(EndpointName="asdf", Body="qwer")
    assert response["Body"].read() == b"first body"

    # Same input -> same output
    response = client.invoke_endpoint(EndpointName="asdf", Body="qwer")
    assert response["Body"].read() == b"first body"

    # Different input -> second item
    response = client.invoke_endpoint(
        EndpointName="asdf", Body="qwer", Accept="sth", TargetModel="tm"
    )
    assert response["Body"].read() == b"second body"
|
Loading…
x
Reference in New Issue
Block a user