SNS: Improve PlatformApplication error handling (#6953)

This commit is contained in:
Bert Blommers 2023-10-26 19:48:18 +00:00 committed by GitHub
parent 9c0724e0d0
commit 8fb053fb05
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 239 additions and 116 deletions

View File

@ -31,7 +31,7 @@ class DuplicateSnsEndpointError(SNSException):
code = 400 code = 400
def __init__(self, message: str): def __init__(self, message: str):
super().__init__("DuplicateEndpoint", message) super().__init__("InvalidParameter", message)
class SnsEndpointDisabled(SNSException): class SnsEndpointDisabled(SNSException):

View File

@ -364,7 +364,7 @@ class PlatformEndpoint(BaseModel):
def publish(self, message: str) -> str: def publish(self, message: str) -> str:
if not self.enabled: if not self.enabled:
raise SnsEndpointDisabled(f"Endpoint {self.id} disabled") raise SnsEndpointDisabled("Endpoint is disabled")
# This is where we would actually send a message # This is where we would actually send a message
message_id = str(mock_random.uuid4()) message_id = str(mock_random.uuid4())
@ -646,7 +646,7 @@ class SNSBackend(BaseBackend):
try: try:
return self.applications[arn] return self.applications[arn]
except KeyError: except KeyError:
raise SNSNotFoundError(f"Application with arn {arn} not found") raise SNSNotFoundError("PlatformApplication does not exist")
def set_platform_application_attributes( def set_platform_application_attributes(
self, arn: str, attributes: Dict[str, Any] self, arn: str, attributes: Dict[str, Any]
@ -673,13 +673,16 @@ class SNSBackend(BaseBackend):
) -> PlatformEndpoint: ) -> PlatformEndpoint:
for endpoint in self.platform_endpoints.values(): for endpoint in self.platform_endpoints.values():
if token == endpoint.token: if token == endpoint.token:
if ( same_user_data = custom_user_data == endpoint.custom_user_data
same_attrs = (
attributes.get("Enabled", "").lower() attributes.get("Enabled", "").lower()
== endpoint.attributes["Enabled"] == endpoint.attributes["Enabled"]
): )
if same_user_data and same_attrs:
return endpoint return endpoint
raise DuplicateSnsEndpointError( raise DuplicateSnsEndpointError(
f"Duplicate endpoint token with different attributes: {token}" f"Invalid parameter: Token Reason: Endpoint {endpoint.arn} already exists with the same Token, but different attributes."
) )
platform_endpoint = PlatformEndpoint( platform_endpoint = PlatformEndpoint(
self.account_id, self.account_id,

View File

@ -558,6 +558,7 @@ class SNSResponse(BaseResponse):
def list_endpoints_by_platform_application(self) -> str: def list_endpoints_by_platform_application(self) -> str:
application_arn = self._get_param("PlatformApplicationArn") application_arn = self._get_param("PlatformApplicationArn")
self.backend.get_application(application_arn)
endpoints = self.backend.list_endpoints_by_platform_application(application_arn) endpoints = self.backend.list_endpoints_by_platform_application(application_arn)
if self.request_json: if self.request_json:

View File

@ -1 +1,48 @@
# This file is intentionally left blank. import boto3
import botocore
import os
from functools import wraps
from moto import mock_sns, mock_sts
from unittest import SkipTest
def sns_aws_verified(func):
"""
Function that is verified to work against AWS.
Can be run against AWS at any time by setting:
MOTO_TEST_ALLOW_AWS_REQUEST=true
If this environment variable is not set, the function runs in a `mock_ses` context.
"""
@wraps(func)
def pagination_wrapper():
allow_aws_request = (
os.environ.get("MOTO_TEST_ALLOW_AWS_REQUEST", "false").lower() == "true"
)
if allow_aws_request:
ssm = boto3.client("ssm", "us-east-1")
try:
param = ssm.get_parameter(
Name="/moto/tests/ses/firebase_api_key", WithDecryption=True
)
api_key = param["Parameter"]["Value"]
resp = func(api_key)
except botocore.exceptions.ClientError:
# SNS tests try to create a PlatformApplication that connects to GCM
# (Google Cloud Messaging, also known as Firebase Messaging)
# That requires an account and an API key
# If the API key has not been configured in SSM, we'll just skip the test
#
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sns/client/create_platform_application.html
# AWS calls it 'API key', but Firebase calls it Server Key
#
# https://stackoverflow.com/a/75896532/13245310
raise SkipTest("Can't execute SNS tests without Firebase API key")
else:
with mock_sns(), mock_sts():
resp = func("mock_api_key")
return resp
return pagination_wrapper

View File

@ -1,10 +1,13 @@
import boto3 import boto3
from botocore.exceptions import ClientError
import pytest import pytest
from botocore.exceptions import ClientError
from uuid import uuid4
from moto import mock_sns from moto import mock_sns
from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID
from . import sns_aws_verified
@mock_sns @mock_sns
def test_create_platform_application(): def test_create_platform_application():
@ -133,80 +136,146 @@ def test_create_platform_endpoint():
) )
@mock_sns @pytest.mark.aws_verified
def test_create_duplicate_platform_endpoint(): @sns_aws_verified
def test_create_duplicate_platform_endpoint(api_key=None):
identity = boto3.client("sts", region_name="us-east-1").get_caller_identity()
account_id = identity["Account"]
conn = boto3.client("sns", region_name="us-east-1") conn = boto3.client("sns", region_name="us-east-1")
platform_application = conn.create_platform_application( platform_name = str(uuid4())[0:6]
Name="my-application", Platform="APNS", Attributes={} application_arn = None
) try:
application_arn = platform_application["PlatformApplicationArn"] platform_application = conn.create_platform_application(
Name=platform_name,
conn.create_platform_endpoint( Platform="GCM",
PlatformApplicationArn=application_arn, Attributes={"PlatformCredential": api_key},
Token="some_unique_id", )
CustomUserData="some user data", application_arn = platform_application["PlatformApplicationArn"]
Attributes={"Enabled": "false"}, assert (
) application_arn
== f"arn:aws:sns:us-east-1:{account_id}:app/GCM/{platform_name}"
with pytest.raises(ClientError):
conn.create_platform_endpoint(
PlatformApplicationArn=application_arn,
Token="some_unique_id",
CustomUserData="some user data",
Attributes={"Enabled": "true"},
) )
# WITHOUT custom user data
resp = conn.create_platform_endpoint(
PlatformApplicationArn=application_arn,
Token="token_without_userdata",
Attributes={"Enabled": "false"},
)
endpoint1_arn = resp["EndpointArn"]
@mock_sns # This operation is idempotent
def test_create_duplicate_platform_endpoint_with_same_attributes(): resp = conn.create_platform_endpoint(
PlatformApplicationArn=application_arn,
Token="token_without_userdata",
Attributes={"Enabled": "false"},
)
assert resp["EndpointArn"] == endpoint1_arn
# Same token, but with CustomUserData
with pytest.raises(ClientError) as exc:
conn.create_platform_endpoint(
PlatformApplicationArn=application_arn,
Token="token_without_userdata",
CustomUserData="some user data",
Attributes={"Enabled": "false"},
)
err = exc.value.response["Error"]
assert err["Code"] == "InvalidParameter"
assert (
err["Message"]
== f"Invalid parameter: Token Reason: Endpoint {endpoint1_arn} already exists with the same Token, but different attributes."
)
# WITH custom user data
resp = conn.create_platform_endpoint(
PlatformApplicationArn=application_arn,
Token="token_with_userdata",
CustomUserData="some user data",
Attributes={"Enabled": "false"},
)
endpoint2_arn = resp["EndpointArn"]
assert endpoint2_arn.startswith(
f"arn:aws:sns:us-east-1:{account_id}:endpoint/GCM/{platform_name}/"
)
# Still idempotent
resp = conn.create_platform_endpoint(
PlatformApplicationArn=application_arn,
Token="token_with_userdata",
CustomUserData="some user data",
Attributes={"Enabled": "false"},
)
assert resp["EndpointArn"] == endpoint2_arn
# Creating a platform endpoint with different attributes fails
with pytest.raises(ClientError) as exc:
conn.create_platform_endpoint(
PlatformApplicationArn=application_arn,
Token="token_with_userdata",
CustomUserData="some user data",
Attributes={"Enabled": "true"},
)
err = exc.value.response["Error"]
assert err["Code"] == "InvalidParameter"
assert (
err["Message"]
== f"Invalid parameter: Token Reason: Endpoint {endpoint2_arn} already exists with the same Token, but different attributes."
)
with pytest.raises(ClientError) as exc:
conn.create_platform_endpoint(
PlatformApplicationArn=application_arn + "2",
Token="unknown_arn",
)
err = exc.value.response["Error"]
assert err["Code"] == "NotFound"
assert err["Message"] == "PlatformApplication does not exist"
finally:
if application_arn is not None:
conn.delete_platform_application(PlatformApplicationArn=application_arn)
@pytest.mark.aws_verified
@sns_aws_verified
def test_get_list_endpoints_by_platform_application(api_key=None):
conn = boto3.client("sns", region_name="us-east-1") conn = boto3.client("sns", region_name="us-east-1")
platform_application = conn.create_platform_application( platform_name = str(uuid4())[0:6]
Name="my-application", Platform="APNS", Attributes={} application_arn = None
) try:
application_arn = platform_application["PlatformApplicationArn"] platform_application = conn.create_platform_application(
Name=platform_name,
Platform="GCM",
Attributes={"PlatformCredential": api_key},
)
application_arn = platform_application["PlatformApplicationArn"]
created_endpoint = conn.create_platform_endpoint( endpoint = conn.create_platform_endpoint(
PlatformApplicationArn=application_arn, PlatformApplicationArn=application_arn,
Token="some_unique_id", Token="some_unique_id",
CustomUserData="some user data", Attributes={"CustomUserData": "some data"},
Attributes={"Enabled": "false"}, )
) endpoint_arn = endpoint["EndpointArn"]
created_endpoint_arn = created_endpoint["EndpointArn"]
endpoint = conn.create_platform_endpoint( endpoint_list = conn.list_endpoints_by_platform_application(
PlatformApplicationArn=application_arn, PlatformApplicationArn=application_arn
Token="some_unique_id", )["Endpoints"]
CustomUserData="some user data",
Attributes={"Enabled": "false"},
)
endpoint_arn = endpoint["EndpointArn"]
assert endpoint_arn == created_endpoint_arn assert len(endpoint_list) == 1
assert endpoint_list[0]["Attributes"]["CustomUserData"] == "some data"
assert endpoint_list[0]["EndpointArn"] == endpoint_arn
finally:
if application_arn is not None:
conn.delete_platform_application(PlatformApplicationArn=application_arn)
with pytest.raises(ClientError) as exc:
@mock_sns conn.list_endpoints_by_platform_application(
def test_get_list_endpoints_by_platform_application(): PlatformApplicationArn=application_arn + "2"
conn = boto3.client("sns", region_name="us-east-1") )
platform_application = conn.create_platform_application( err = exc.value.response["Error"]
Name="my-application", Platform="APNS", Attributes={} assert err["Code"] == "NotFound"
) assert err["Message"] == "PlatformApplication does not exist"
application_arn = platform_application["PlatformApplicationArn"]
endpoint = conn.create_platform_endpoint(
PlatformApplicationArn=application_arn,
Token="some_unique_id",
CustomUserData="some user data",
Attributes={"CustomUserData": "some data"},
)
endpoint_arn = endpoint["EndpointArn"]
endpoint_list = conn.list_endpoints_by_platform_application(
PlatformApplicationArn=application_arn
)["Endpoints"]
assert len(endpoint_list) == 1
assert endpoint_list[0]["Attributes"]["CustomUserData"] == "some data"
assert endpoint_list[0]["EndpointArn"] == endpoint_arn
@mock_sns @mock_sns
@ -231,13 +300,16 @@ def test_get_endpoint_attributes():
) )
@mock_sns @pytest.mark.aws_verified
def test_get_non_existent_endpoint_attributes(): @sns_aws_verified
def test_get_non_existent_endpoint_attributes(
api_key=None,
): # pylint: disable=unused-argument
identity = boto3.client("sts", region_name="us-east-1").get_caller_identity()
account_id = identity["Account"]
conn = boto3.client("sns", region_name="us-east-1") conn = boto3.client("sns", region_name="us-east-1")
endpoint_arn = ( endpoint_arn = f"arn:aws:sns:us-east-1:{account_id}:endpoint/APNS/my-application/c1f76c42-192a-4e75-b04f-a9268ce2abf3"
"arn:aws:sns:us-east-1:123456789012:endpoint/APNS/my-application"
"/c1f76c42-192a-4e75-b04f-a9268ce2abf3"
)
with pytest.raises(conn.exceptions.NotFoundException) as excinfo: with pytest.raises(conn.exceptions.NotFoundException) as excinfo:
conn.get_endpoint_attributes(EndpointArn=endpoint_arn) conn.get_endpoint_attributes(EndpointArn=endpoint_arn)
error = excinfo.value.response["Error"] error = excinfo.value.response["Error"]
@ -284,33 +356,25 @@ def test_delete_endpoint():
platform_application = conn.create_platform_application( platform_application = conn.create_platform_application(
Name="my-application", Platform="APNS", Attributes={} Name="my-application", Platform="APNS", Attributes={}
) )
application_arn = platform_application["PlatformApplicationArn"] app_arn = platform_application["PlatformApplicationArn"]
endpoint = conn.create_platform_endpoint( endpoint = conn.create_platform_endpoint(
PlatformApplicationArn=application_arn, PlatformApplicationArn=app_arn,
Token="some_unique_id", Token="some_unique_id",
CustomUserData="some user data", CustomUserData="some user data",
Attributes={"Enabled": "true"}, Attributes={"Enabled": "true"},
) )
assert ( endpoints = conn.list_endpoints_by_platform_application(
len( PlatformApplicationArn=app_arn
conn.list_endpoints_by_platform_application( )["Endpoints"]
PlatformApplicationArn=application_arn assert len(endpoints) == 1
)["Endpoints"]
)
== 1
)
conn.delete_endpoint(EndpointArn=endpoint["EndpointArn"]) conn.delete_endpoint(EndpointArn=endpoint["EndpointArn"])
assert ( endpoints = conn.list_endpoints_by_platform_application(
len( PlatformApplicationArn=app_arn
conn.list_endpoints_by_platform_application( )["Endpoints"]
PlatformApplicationArn=application_arn assert len(endpoints) == 0
)["Endpoints"]
)
== 0
)
@mock_sns @mock_sns
@ -335,27 +399,35 @@ def test_publish_to_platform_endpoint():
) )
@mock_sns @pytest.mark.aws_verified
def test_publish_to_disabled_platform_endpoint(): @sns_aws_verified
def test_publish_to_disabled_platform_endpoint(api_key=None):
conn = boto3.client("sns", region_name="us-east-1") conn = boto3.client("sns", region_name="us-east-1")
platform_application = conn.create_platform_application( platform_name = str(uuid4())[0:6]
Name="my-application", Platform="APNS", Attributes={} application_arn = None
) try:
application_arn = platform_application["PlatformApplicationArn"] platform_application = conn.create_platform_application(
Name=platform_name,
endpoint = conn.create_platform_endpoint( Platform="GCM",
PlatformApplicationArn=application_arn, Attributes={"PlatformCredential": api_key},
Token="some_unique_id",
CustomUserData="some user data",
Attributes={"Enabled": "false"},
)
endpoint_arn = endpoint["EndpointArn"]
with pytest.raises(ClientError):
conn.publish(
Message="some message", MessageStructure="json", TargetArn=endpoint_arn
) )
application_arn = platform_application["PlatformApplicationArn"]
endpoint = conn.create_platform_endpoint(
PlatformApplicationArn=application_arn,
Token="some_unique_id",
Attributes={"Enabled": "false"},
)
endpoint_arn = endpoint["EndpointArn"]
with pytest.raises(ClientError) as exc:
conn.publish(Message="msg", MessageStructure="json", TargetArn=endpoint_arn)
err = exc.value.response["Error"]
assert err["Code"] == "EndpointDisabled"
assert err["Message"] == "Endpoint is disabled"
finally:
if application_arn is not None:
conn.delete_platform_application(PlatformApplicationArn=application_arn)
@mock_sns @mock_sns