diff --git a/moto/sns/models.py b/moto/sns/models.py index 2bde642ac..58f93735a 100644 --- a/moto/sns/models.py +++ b/moto/sns/models.py @@ -584,10 +584,13 @@ class SNSBackend(BaseBackend): def create_platform_endpoint( self, region, application, custom_user_data, token, attributes ): - if any( - token == endpoint.token for endpoint in self.platform_endpoints.values() - ): - raise DuplicateSnsEndpointError("Duplicate endpoint token: %s" % token) + for endpoint in self.platform_endpoints.values(): + if token == endpoint.token: + if attributes["Enabled"].lower() == endpoint.attributes["Enabled"]: + return endpoint + raise DuplicateSnsEndpointError( + "Duplicate endpoint token with different attributes: %s" % token + ) platform_endpoint = PlatformEndpoint( region, application, custom_user_data, token, attributes ) diff --git a/tests/test_sns/test_application_boto3.py b/tests/test_sns/test_application_boto3.py index 2720c3eab..9cd3ca73e 100644 --- a/tests/test_sns/test_application_boto3.py +++ b/tests/test_sns/test_application_boto3.py @@ -156,10 +156,37 @@ def test_create_duplicate_platform_endpoint(): PlatformApplicationArn=application_arn, Token="some_unique_id", CustomUserData="some user data", - Attributes={"Enabled": "false"}, + Attributes={"Enabled": "true"}, ).should.throw(ClientError) +@mock_sns +def test_create_duplicate_platform_endpoint_with_same_attributes(): + conn = boto3.client("sns", region_name="us-east-1") + platform_application = conn.create_platform_application( + Name="my-application", Platform="APNS", Attributes={} + ) + application_arn = platform_application["PlatformApplicationArn"] + + created_endpoint = conn.create_platform_endpoint( + PlatformApplicationArn=application_arn, + Token="some_unique_id", + CustomUserData="some user data", + Attributes={"Enabled": "false"}, + ) + created_endpoint_arn = created_endpoint["EndpointArn"] + + endpoint = conn.create_platform_endpoint( + PlatformApplicationArn=application_arn, + Token="some_unique_id", + CustomUserData="some user data", + Attributes={"Enabled": "false"}, + ) + endpoint_arn = endpoint["EndpointArn"] + + endpoint_arn.should.equal(created_endpoint_arn) + + @mock_sns def test_get_list_endpoints_by_platform_application(): conn = boto3.client("sns", region_name="us-east-1")