Techdebt: Replace sure with regular assertions in RedShift (#6690)
This commit is contained in:
		
							parent
							
								
									f57bb251c8
								
							
						
					
					
						commit
						ca83236da6
					
				| @ -1,14 +1,15 @@ | |||||||
| import string |  | ||||||
| import re |  | ||||||
| import responses |  | ||||||
| import requests |  | ||||||
| import time |  | ||||||
| from datetime import datetime |  | ||||||
| from collections import defaultdict | from collections import defaultdict | ||||||
| from openapi_spec_validator import validate_spec | from datetime import datetime | ||||||
|  | import re | ||||||
|  | import string | ||||||
|  | import time | ||||||
| from typing import Any, Dict, List, Optional, Tuple, Union | from typing import Any, Dict, List, Optional, Tuple, Union | ||||||
| from urllib.parse import urlparse | from urllib.parse import urlparse | ||||||
| 
 | 
 | ||||||
|  | from openapi_spec_validator import validate_spec | ||||||
|  | import requests | ||||||
|  | import responses | ||||||
|  | 
 | ||||||
| try: | try: | ||||||
|     from openapi_spec_validator.validation.exceptions import OpenAPIValidationError |     from openapi_spec_validator.validation.exceptions import OpenAPIValidationError | ||||||
| except ImportError: | except ImportError: | ||||||
| @ -1511,7 +1512,7 @@ class APIGatewayBackend(BaseBackend): | |||||||
|             integrationHttpMethod="GET" |             integrationHttpMethod="GET" | ||||||
|         ) |         ) | ||||||
|         deploy_url = f"https://{api_id}.execute-api.us-east-1.amazonaws.com/dev" |         deploy_url = f"https://{api_id}.execute-api.us-east-1.amazonaws.com/dev" | ||||||
|         requests.get(deploy_url).content.should.equal(b"a fake response") |         assert requests.get(deploy_url).content == b"a fake response" | ||||||
| 
 | 
 | ||||||
|     Limitations: |     Limitations: | ||||||
|      - Integrations of type HTTP are supported |      - Integrations of type HTTP are supported | ||||||
|  | |||||||
| @ -1,12 +1,13 @@ | |||||||
| import time |  | ||||||
| from datetime import datetime | from datetime import datetime | ||||||
|  | import time | ||||||
|  | from typing import Any, Dict, List, Optional | ||||||
|  | 
 | ||||||
| from moto.core import BaseBackend, BackendDict, BaseModel | from moto.core import BaseBackend, BackendDict, BaseModel | ||||||
| from moto.moto_api._internal import mock_random | from moto.moto_api._internal import mock_random | ||||||
| from moto.utilities.paginator import paginate | from moto.utilities.paginator import paginate | ||||||
| from typing import Any, Dict, List, Optional |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class TaggableResourceMixin(object): | class TaggableResourceMixin: | ||||||
|     # This mixing was copied from Redshift when initially implementing |     # This mixing was copied from Redshift when initially implementing | ||||||
|     # Athena. TBD if it's worth the overhead. |     # Athena. TBD if it's worth the overhead. | ||||||
| 
 | 
 | ||||||
| @ -263,7 +264,7 @@ class AthenaBackend(BaseBackend): | |||||||
|                 "http://motoapi.amazonaws.com:5000/moto-api/static/athena/query-results", |                 "http://motoapi.amazonaws.com:5000/moto-api/static/athena/query-results", | ||||||
|                 json=expected_results, |                 json=expected_results, | ||||||
|             ) |             ) | ||||||
|             resp.status_code.should.equal(201) |             assert resp.status_code == 201 | ||||||
| 
 | 
 | ||||||
|             client = boto3.client("athena", region_name="us-east-1") |             client = boto3.client("athena", region_name="us-east-1") | ||||||
|             details = client.get_query_execution(QueryExecutionId="any_id")["QueryExecution"] |             details = client.get_query_execution(QueryExecutionId="any_id")["QueryExecution"] | ||||||
|  | |||||||
| @ -1,7 +1,7 @@ | |||||||
| from moto.core import BaseBackend, BackendDict |  | ||||||
| 
 |  | ||||||
| from typing import Any, Dict, List, Optional, Tuple | from typing import Any, Dict, List, Optional, Tuple | ||||||
| 
 | 
 | ||||||
|  | from moto.core import BaseBackend, BackendDict | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| class QueryResults: | class QueryResults: | ||||||
|     def __init__( |     def __init__( | ||||||
| @ -66,7 +66,7 @@ class RDSDataServiceBackend(BaseBackend): | |||||||
|                 "http://motoapi.amazonaws.com:5000/moto-api/static/rds-data/statement-results", |                 "http://motoapi.amazonaws.com:5000/moto-api/static/rds-data/statement-results", | ||||||
|                 json=expected_results, |                 json=expected_results, | ||||||
|             ) |             ) | ||||||
|             resp.status_code.should.equal(201) |             assert resp.status_code == 201 | ||||||
| 
 | 
 | ||||||
|             rdsdata = boto3.client("rds-data", region_name="us-east-1") |             rdsdata = boto3.client("rds-data", region_name="us-east-1") | ||||||
|             resp = rdsdata.execute_statement(resourceArn="not applicable", secretArn="not applicable", sql="SELECT some FROM thing") |             resp = rdsdata.execute_statement(resourceArn="not applicable", secretArn="not applicable", sql="SELECT some FROM thing") | ||||||
|  | |||||||
| @ -1,8 +1,10 @@ | |||||||
|  | from collections import OrderedDict | ||||||
| import copy | import copy | ||||||
| import datetime | import datetime | ||||||
| 
 |  | ||||||
| from collections import OrderedDict |  | ||||||
| from typing import Any, Dict, Iterable, List, Optional | from typing import Any, Dict, Iterable, List, Optional | ||||||
|  | 
 | ||||||
|  | from dateutil.tz import tzutc | ||||||
|  | 
 | ||||||
| from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel | from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel | ||||||
| from moto.core.utils import iso_8601_datetime_with_milliseconds | from moto.core.utils import iso_8601_datetime_with_milliseconds | ||||||
| from moto.ec2 import ec2_backends | from moto.ec2 import ec2_backends | ||||||
| @ -46,7 +48,10 @@ class TaggableResourceMixin: | |||||||
| 
 | 
 | ||||||
|     @property |     @property | ||||||
|     def arn(self) -> str: |     def arn(self) -> str: | ||||||
|         return f"arn:aws:redshift:{self.region}:{self.account_id}:{self.resource_type}:{self.resource_id}" |         return ( | ||||||
|  |             f"arn:aws:redshift:{self.region}:{self.account_id}" | ||||||
|  |             f":{self.resource_type}:{self.resource_id}" | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|     def create_tags(self, tags: List[Dict[str, str]]) -> List[Dict[str, str]]: |     def create_tags(self, tags: List[Dict[str, str]]) -> List[Dict[str, str]]: | ||||||
|         new_keys = [tag_set["Key"] for tag_set in tags] |         new_keys = [tag_set["Key"] for tag_set in tags] | ||||||
| @ -95,7 +100,7 @@ class Cluster(TaggableResourceMixin, CloudFormationModel): | |||||||
|         self.redshift_backend = redshift_backend |         self.redshift_backend = redshift_backend | ||||||
|         self.cluster_identifier = cluster_identifier |         self.cluster_identifier = cluster_identifier | ||||||
|         self.create_time = iso_8601_datetime_with_milliseconds( |         self.create_time = iso_8601_datetime_with_milliseconds( | ||||||
|             datetime.datetime.utcnow() |             datetime.datetime.now(tzutc()) | ||||||
|         ) |         ) | ||||||
|         self.status = "available" |         self.status = "available" | ||||||
|         self.node_type = node_type |         self.node_type = node_type | ||||||
| @ -220,7 +225,7 @@ class Cluster(TaggableResourceMixin, CloudFormationModel): | |||||||
| 
 | 
 | ||||||
|         if attribute_name == "Endpoint.Address": |         if attribute_name == "Endpoint.Address": | ||||||
|             return self.endpoint |             return self.endpoint | ||||||
|         elif attribute_name == "Endpoint.Port": |         if attribute_name == "Endpoint.Port": | ||||||
|             return self.port |             return self.port | ||||||
|         raise UnformattedGetAttTemplateException() |         raise UnformattedGetAttTemplateException() | ||||||
| 
 | 
 | ||||||
| @ -528,7 +533,7 @@ class Snapshot(TaggableResourceMixin, BaseModel): | |||||||
|         self.snapshot_type = snapshot_type |         self.snapshot_type = snapshot_type | ||||||
|         self.status = "available" |         self.status = "available" | ||||||
|         self.create_time = iso_8601_datetime_with_milliseconds( |         self.create_time = iso_8601_datetime_with_milliseconds( | ||||||
|             datetime.datetime.utcnow() |             datetime.datetime.now(tzutc()) | ||||||
|         ) |         ) | ||||||
|         self.iam_roles_arn = iam_roles_arn or [] |         self.iam_roles_arn = iam_roles_arn or [] | ||||||
| 
 | 
 | ||||||
| @ -622,8 +627,7 @@ class RedshiftBackend(BaseBackend): | |||||||
|             } |             } | ||||||
|             cluster.cluster_snapshot_copy_status = status |             cluster.cluster_snapshot_copy_status = status | ||||||
|             return cluster |             return cluster | ||||||
|         else: |         raise SnapshotCopyAlreadyEnabledFaultError(cluster_identifier) | ||||||
|             raise SnapshotCopyAlreadyEnabledFaultError(cluster_identifier) |  | ||||||
| 
 | 
 | ||||||
|     def disable_snapshot_copy(self, **kwargs: Any) -> Cluster: |     def disable_snapshot_copy(self, **kwargs: Any) -> Cluster: | ||||||
|         cluster_identifier = kwargs["cluster_identifier"] |         cluster_identifier = kwargs["cluster_identifier"] | ||||||
| @ -631,8 +635,7 @@ class RedshiftBackend(BaseBackend): | |||||||
|         if cluster.cluster_snapshot_copy_status is not None: |         if cluster.cluster_snapshot_copy_status is not None: | ||||||
|             cluster.cluster_snapshot_copy_status = None |             cluster.cluster_snapshot_copy_status = None | ||||||
|             return cluster |             return cluster | ||||||
|         else: |         raise SnapshotCopyAlreadyDisabledFaultError(cluster_identifier) | ||||||
|             raise SnapshotCopyAlreadyDisabledFaultError(cluster_identifier) |  | ||||||
| 
 | 
 | ||||||
|     def modify_snapshot_copy_retention_period( |     def modify_snapshot_copy_retention_period( | ||||||
|         self, cluster_identifier: str, retention_period: str |         self, cluster_identifier: str, retention_period: str | ||||||
| @ -650,7 +653,10 @@ class RedshiftBackend(BaseBackend): | |||||||
|             raise ClusterAlreadyExistsFaultError() |             raise ClusterAlreadyExistsFaultError() | ||||||
|         cluster = Cluster(self, **cluster_kwargs) |         cluster = Cluster(self, **cluster_kwargs) | ||||||
|         self.clusters[cluster_identifier] = cluster |         self.clusters[cluster_identifier] = cluster | ||||||
|         snapshot_id = f"rs:{cluster_identifier}-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}" |         snapshot_id = ( | ||||||
|  |             f"rs:{cluster_identifier}-" | ||||||
|  |             f"{datetime.datetime.now(tzutc()).strftime('%Y-%m-%d-%H-%M')}" | ||||||
|  |         ) | ||||||
|         # Automated snapshots don't copy over the tags |         # Automated snapshots don't copy over the tags | ||||||
|         self.create_cluster_snapshot( |         self.create_cluster_snapshot( | ||||||
|             cluster_identifier, |             cluster_identifier, | ||||||
| @ -679,8 +685,7 @@ class RedshiftBackend(BaseBackend): | |||||||
|         if cluster_identifier: |         if cluster_identifier: | ||||||
|             if cluster_identifier in self.clusters: |             if cluster_identifier in self.clusters: | ||||||
|                 return [self.clusters[cluster_identifier]] |                 return [self.clusters[cluster_identifier]] | ||||||
|             else: |             raise ClusterNotFoundError(cluster_identifier) | ||||||
|                 raise ClusterNotFoundError(cluster_identifier) |  | ||||||
|         return list(self.clusters.values()) |         return list(self.clusters.values()) | ||||||
| 
 | 
 | ||||||
|     def modify_cluster(self, **cluster_kwargs: Any) -> Cluster: |     def modify_cluster(self, **cluster_kwargs: Any) -> Cluster: | ||||||
| @ -739,9 +744,10 @@ class RedshiftBackend(BaseBackend): | |||||||
|                 and cluster_snapshot_identifer is None |                 and cluster_snapshot_identifer is None | ||||||
|             ): |             ): | ||||||
|                 raise InvalidParameterCombinationError( |                 raise InvalidParameterCombinationError( | ||||||
|                     "FinalClusterSnapshotIdentifier is required unless SkipFinalClusterSnapshot is specified." |                     "FinalClusterSnapshotIdentifier is required unless " | ||||||
|  |                     "SkipFinalClusterSnapshot is specified." | ||||||
|                 ) |                 ) | ||||||
|             elif ( |             if ( | ||||||
|                 cluster_skip_final_snapshot is False |                 cluster_skip_final_snapshot is False | ||||||
|                 and cluster_snapshot_identifer is not None |                 and cluster_snapshot_identifer is not None | ||||||
|             ):  # create snapshot |             ):  # create snapshot | ||||||
| @ -781,8 +787,7 @@ class RedshiftBackend(BaseBackend): | |||||||
|         if subnet_identifier: |         if subnet_identifier: | ||||||
|             if subnet_identifier in self.subnet_groups: |             if subnet_identifier in self.subnet_groups: | ||||||
|                 return [self.subnet_groups[subnet_identifier]] |                 return [self.subnet_groups[subnet_identifier]] | ||||||
|             else: |             raise ClusterSubnetGroupNotFoundError(subnet_identifier) | ||||||
|                 raise ClusterSubnetGroupNotFoundError(subnet_identifier) |  | ||||||
|         return list(self.subnet_groups.values()) |         return list(self.subnet_groups.values()) | ||||||
| 
 | 
 | ||||||
|     def delete_cluster_subnet_group(self, subnet_identifier: str) -> SubnetGroup: |     def delete_cluster_subnet_group(self, subnet_identifier: str) -> SubnetGroup: | ||||||
| @ -812,8 +817,7 @@ class RedshiftBackend(BaseBackend): | |||||||
|         if security_group_name: |         if security_group_name: | ||||||
|             if security_group_name in self.security_groups: |             if security_group_name in self.security_groups: | ||||||
|                 return [self.security_groups[security_group_name]] |                 return [self.security_groups[security_group_name]] | ||||||
|             else: |             raise ClusterSecurityGroupNotFoundError(security_group_name) | ||||||
|                 raise ClusterSecurityGroupNotFoundError(security_group_name) |  | ||||||
|         return list(self.security_groups.values()) |         return list(self.security_groups.values()) | ||||||
| 
 | 
 | ||||||
|     def delete_cluster_security_group( |     def delete_cluster_security_group( | ||||||
| @ -861,8 +865,7 @@ class RedshiftBackend(BaseBackend): | |||||||
|         if parameter_group_name: |         if parameter_group_name: | ||||||
|             if parameter_group_name in self.parameter_groups: |             if parameter_group_name in self.parameter_groups: | ||||||
|                 return [self.parameter_groups[parameter_group_name]] |                 return [self.parameter_groups[parameter_group_name]] | ||||||
|             else: |             raise ClusterParameterGroupNotFoundError(parameter_group_name) | ||||||
|                 raise ClusterParameterGroupNotFoundError(parameter_group_name) |  | ||||||
|         return list(self.parameter_groups.values()) |         return list(self.parameter_groups.values()) | ||||||
| 
 | 
 | ||||||
|     def delete_cluster_parameter_group( |     def delete_cluster_parameter_group( | ||||||
| @ -983,8 +986,7 @@ class RedshiftBackend(BaseBackend): | |||||||
|         if snapshot_copy_grant_name: |         if snapshot_copy_grant_name: | ||||||
|             if snapshot_copy_grant_name in self.snapshot_copy_grants: |             if snapshot_copy_grant_name in self.snapshot_copy_grants: | ||||||
|                 return [self.snapshot_copy_grants[snapshot_copy_grant_name]] |                 return [self.snapshot_copy_grants[snapshot_copy_grant_name]] | ||||||
|             else: |             raise SnapshotCopyGrantNotFoundFaultError(snapshot_copy_grant_name) | ||||||
|                 raise SnapshotCopyGrantNotFoundFaultError(snapshot_copy_grant_name) |  | ||||||
|         return copy_grants |         return copy_grants | ||||||
| 
 | 
 | ||||||
|     def _get_resource_from_arn(self, arn: str) -> TaggableResourceMixin: |     def _get_resource_from_arn(self, arn: str) -> TaggableResourceMixin: | ||||||
| @ -1000,16 +1002,16 @@ class RedshiftBackend(BaseBackend): | |||||||
|         resources = self.RESOURCE_TYPE_MAP.get(resource_type) |         resources = self.RESOURCE_TYPE_MAP.get(resource_type) | ||||||
|         if resources is None: |         if resources is None: | ||||||
|             message = ( |             message = ( | ||||||
|                 f"Tagging is not supported for this type of resource: '{resource_type}' " |                 "Tagging is not supported for this type of resource: " | ||||||
|                 "(the ARN is potentially malformed, please check the ARN documentation for more information)" |                 f"'{resource_type}' (the ARN is potentially malformed, " | ||||||
|  |                 "please check the ARN documentation for more information)" | ||||||
|             ) |             ) | ||||||
|             raise ResourceNotFoundFaultError(message=message) |             raise ResourceNotFoundFaultError(message=message) | ||||||
|         try: |         try: | ||||||
|             resource = resources[resource_id] |             resource = resources[resource_id] | ||||||
|         except KeyError: |         except KeyError: | ||||||
|             raise ResourceNotFoundFaultError(resource_type, resource_id) |             raise ResourceNotFoundFaultError(resource_type, resource_id) | ||||||
|         else: |         return resource | ||||||
|             return resource |  | ||||||
| 
 | 
 | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def _describe_tags_for_resources(resources: Iterable[Any]) -> List[Dict[str, Any]]:  # type: ignore[misc] |     def _describe_tags_for_resources(resources: Iterable[Any]) -> List[Dict[str, Any]]:  # type: ignore[misc] | ||||||
| @ -1081,7 +1083,7 @@ class RedshiftBackend(BaseBackend): | |||||||
|             raise InvalidParameterValueError( |             raise InvalidParameterValueError( | ||||||
|                 "Token duration must be between 900 and 3600 seconds" |                 "Token duration must be between 900 and 3600 seconds" | ||||||
|             ) |             ) | ||||||
|         expiration = datetime.datetime.utcnow() + datetime.timedelta( |         expiration = datetime.datetime.now(tzutc()) + datetime.timedelta( | ||||||
|             0, duration_seconds |             0, duration_seconds | ||||||
|         ) |         ) | ||||||
|         if cluster_identifier in self.clusters: |         if cluster_identifier in self.clusters: | ||||||
| @ -1092,8 +1094,7 @@ class RedshiftBackend(BaseBackend): | |||||||
|                 "DbPassword": mock_random.get_random_string(32), |                 "DbPassword": mock_random.get_random_string(32), | ||||||
|                 "Expiration": expiration, |                 "Expiration": expiration, | ||||||
|             } |             } | ||||||
|         else: |         raise ClusterNotFoundError(cluster_identifier) | ||||||
|             raise ClusterNotFoundError(cluster_identifier) |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| redshift_backends = BackendDict(RedshiftBackend, "redshift") | redshift_backends = BackendDict(RedshiftBackend, "redshift") | ||||||
|  | |||||||
| @ -3,6 +3,5 @@ pytest | |||||||
| pytest-cov | pytest-cov | ||||||
| pytest-ordering | pytest-ordering | ||||||
| pytest-xdist | pytest-xdist | ||||||
| surer |  | ||||||
| freezegun | freezegun | ||||||
| pylint | pylint | ||||||
|  | |||||||
| @ -1,15 +1,17 @@ | |||||||
|  | from unittest import SkipTest | ||||||
|  | 
 | ||||||
| import boto3 | import boto3 | ||||||
| import pytest | import pytest | ||||||
|  | 
 | ||||||
| from moto import mock_s3 | from moto import mock_s3 | ||||||
| from moto import settings | from moto import settings | ||||||
| from unittest import SkipTest |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @pytest.fixture(scope="function", name="aws_credentials") | @pytest.fixture(scope="function", name="aws_credentials") | ||||||
| def fixture_aws_credentials(monkeypatch): | def fixture_aws_credentials(monkeypatch): | ||||||
|  |     """Mocked AWS Credentials for moto.""" | ||||||
|     if settings.TEST_SERVER_MODE: |     if settings.TEST_SERVER_MODE: | ||||||
|         raise SkipTest("No point in testing this in ServerMode.") |         raise SkipTest("No point in testing this in ServerMode.") | ||||||
|     """Mocked AWS Credentials for moto.""" |  | ||||||
|     monkeypatch.setenv("AWS_ACCESS_KEY_ID", "testing") |     monkeypatch.setenv("AWS_ACCESS_KEY_ID", "testing") | ||||||
|     monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "testing") |     monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "testing") | ||||||
|     monkeypatch.setenv("AWS_SECURITY_TOKEN", "testing") |     monkeypatch.setenv("AWS_SECURITY_TOKEN", "testing") | ||||||
| @ -63,8 +65,7 @@ def test_mock_works_with_resource_created_outside( | |||||||
| 
 | 
 | ||||||
|     patch_resource(outside_resource) |     patch_resource(outside_resource) | ||||||
| 
 | 
 | ||||||
|     b = list(outside_resource.buckets.all()) |     assert not list(outside_resource.buckets.all()) | ||||||
|     assert b == [] |  | ||||||
|     m.stop() |     m.stop() | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -126,7 +127,7 @@ def test_mock_works_when_replacing_client( | |||||||
|     try: |     try: | ||||||
|         logic.do_important_things() |         logic.do_important_things() | ||||||
|     except Exception as e: |     except Exception as e: | ||||||
|         str(e).should.contain("InvalidAccessKeyId") |         assert str(e) == "InvalidAccessKeyId" | ||||||
| 
 | 
 | ||||||
|     client_initialized_after_mock = boto3.client("s3", region_name="us-east-1") |     client_initialized_after_mock = boto3.client("s3", region_name="us-east-1") | ||||||
|     logic._s3 = client_initialized_after_mock |     logic._s3 = client_initialized_after_mock | ||||||
|  | |||||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @ -1,6 +1,6 @@ | |||||||
| import boto3 |  | ||||||
| import json | import json | ||||||
| import sure  # noqa # pylint: disable=unused-import | 
 | ||||||
|  | import boto3 | ||||||
| 
 | 
 | ||||||
| from moto import mock_cloudformation, mock_ec2, mock_redshift | from moto import mock_cloudformation, mock_ec2, mock_redshift | ||||||
| from tests.test_cloudformation.fixtures import redshift | from tests.test_cloudformation.fixtures import redshift | ||||||
| @ -33,20 +33,20 @@ def test_redshift_stack(): | |||||||
| 
 | 
 | ||||||
|     cluster_res = redshift_conn.describe_clusters() |     cluster_res = redshift_conn.describe_clusters() | ||||||
|     clusters = cluster_res["Clusters"] |     clusters = cluster_res["Clusters"] | ||||||
|     clusters.should.have.length_of(1) |     assert len(clusters) == 1 | ||||||
|     cluster = clusters[0] |     cluster = clusters[0] | ||||||
|     cluster["DBName"].should.equal("mydb") |     assert cluster["DBName"] == "mydb" | ||||||
|     cluster["NumberOfNodes"].should.equal(2) |     assert cluster["NumberOfNodes"] == 2 | ||||||
|     cluster["NodeType"].should.equal("dw1.xlarge") |     assert cluster["NodeType"] == "dw1.xlarge" | ||||||
|     cluster["MasterUsername"].should.equal("myuser") |     assert cluster["MasterUsername"] == "myuser" | ||||||
|     cluster["Endpoint"]["Port"].should.equal(5439) |     assert cluster["Endpoint"]["Port"] == 5439 | ||||||
|     cluster["VpcSecurityGroups"].should.have.length_of(1) |     assert len(cluster["VpcSecurityGroups"]) == 1 | ||||||
|     security_group_id = cluster["VpcSecurityGroups"][0]["VpcSecurityGroupId"] |     security_group_id = cluster["VpcSecurityGroups"][0]["VpcSecurityGroupId"] | ||||||
| 
 | 
 | ||||||
|     groups = ec2.describe_security_groups(GroupIds=[security_group_id])[ |     groups = ec2.describe_security_groups(GroupIds=[security_group_id])[ | ||||||
|         "SecurityGroups" |         "SecurityGroups" | ||||||
|     ] |     ] | ||||||
|     groups.should.have.length_of(1) |     assert len(groups) == 1 | ||||||
|     group = groups[0] |     group = groups[0] | ||||||
|     group["IpPermissions"].should.have.length_of(1) |     assert len(group["IpPermissions"]) == 1 | ||||||
|     group["IpPermissions"][0]["IpRanges"][0]["CidrIp"].should.equal("10.0.0.1/16") |     assert group["IpPermissions"][0]["IpRanges"][0]["CidrIp"] == "10.0.0.1/16" | ||||||
|  | |||||||
| @ -1,15 +1,13 @@ | |||||||
| import sure  # noqa # pylint: disable=unused-import | """Test the different server responses.""" | ||||||
| import json | import json | ||||||
|  | import re | ||||||
|  | 
 | ||||||
| import pytest | import pytest | ||||||
| import xmltodict | import xmltodict | ||||||
| 
 | 
 | ||||||
| import moto.server as server | import moto.server as server | ||||||
| from moto import mock_redshift | from moto import mock_redshift | ||||||
| 
 | 
 | ||||||
| """ |  | ||||||
| Test the different server responses |  | ||||||
| """ |  | ||||||
| 
 |  | ||||||
| 
 | 
 | ||||||
| @mock_redshift | @mock_redshift | ||||||
| def test_describe_clusters(): | def test_describe_clusters(): | ||||||
| @ -19,7 +17,7 @@ def test_describe_clusters(): | |||||||
|     res = test_client.get("/?Action=DescribeClusters") |     res = test_client.get("/?Action=DescribeClusters") | ||||||
| 
 | 
 | ||||||
|     result = res.data.decode("utf-8") |     result = res.data.decode("utf-8") | ||||||
|     result.should.contain("<Clusters></Clusters>") |     assert "<Clusters></Clusters>" in result | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @mock_redshift | @mock_redshift | ||||||
| @ -31,9 +29,9 @@ def test_describe_clusters_with_json_content_type(): | |||||||
| 
 | 
 | ||||||
|     result = json.loads(res.data.decode("utf-8")) |     result = json.loads(res.data.decode("utf-8")) | ||||||
|     del result["DescribeClustersResponse"]["ResponseMetadata"] |     del result["DescribeClustersResponse"]["ResponseMetadata"] | ||||||
|     result.should.equal( |     assert result == { | ||||||
|         {"DescribeClustersResponse": {"DescribeClustersResult": {"Clusters": []}}} |         "DescribeClustersResponse": {"DescribeClustersResult": {"Clusters": []}} | ||||||
|     ) |     } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @pytest.mark.parametrize("is_json", [True, False], ids=["JSON", "XML"]) | @pytest.mark.parametrize("is_json", [True, False], ids=["JSON", "XML"]) | ||||||
| @ -61,30 +59,29 @@ def test_create_cluster(is_json): | |||||||
|         result = xmltodict.parse(result, dict_constructor=dict) |         result = xmltodict.parse(result, dict_constructor=dict) | ||||||
| 
 | 
 | ||||||
|     del result["CreateClusterResponse"]["ResponseMetadata"] |     del result["CreateClusterResponse"]["ResponseMetadata"] | ||||||
|     result.should.have.key("CreateClusterResponse") |     assert "CreateClusterResponse" in result | ||||||
|     result["CreateClusterResponse"].should.have.key("CreateClusterResult") |     assert "CreateClusterResult" in result["CreateClusterResponse"] | ||||||
|     result["CreateClusterResponse"]["CreateClusterResult"].should.have.key("Cluster") |     assert "Cluster" in result["CreateClusterResponse"]["CreateClusterResult"] | ||||||
|     result = result["CreateClusterResponse"]["CreateClusterResult"]["Cluster"] |     result = result["CreateClusterResponse"]["CreateClusterResult"]["Cluster"] | ||||||
| 
 | 
 | ||||||
|     result.should.have.key("MasterUsername").equal("masteruser") |     assert result["MasterUsername"] == "masteruser" | ||||||
|     result.should.have.key("MasterUserPassword").equal("****") |     assert result["MasterUserPassword"] == "****" | ||||||
|     result.should.have.key("ClusterVersion").equal("1.0") |     assert result["ClusterVersion"] == "1.0" | ||||||
|     result.should.have.key("ClusterSubnetGroupName").equal(None) |     assert result["ClusterSubnetGroupName"] is None | ||||||
|     result.should.have.key("AvailabilityZone").equal("us-east-1a") |     assert result["AvailabilityZone"] == "us-east-1a" | ||||||
|     result.should.have.key("ClusterStatus").equal("creating") |     assert result["ClusterStatus"] == "creating" | ||||||
|     result.should.have.key("NumberOfNodes").equal(1 if is_json else "1") |     assert result["NumberOfNodes"] == (1 if is_json else "1") | ||||||
|     result.should.have.key("PubliclyAccessible").equal(None) |     assert result["PubliclyAccessible"] is None | ||||||
|     result.should.have.key("Encrypted").equal(None) |     assert result["Encrypted"] is None | ||||||
|     result.should.have.key("DBName").equal("dev") |     assert result["DBName"] == "dev" | ||||||
|     result.should.have.key("NodeType").equal("ds2.xlarge") |     assert result["NodeType"] == "ds2.xlarge" | ||||||
|     result.should.have.key("ClusterIdentifier").equal("examplecluster") |     assert result["ClusterIdentifier"] == "examplecluster" | ||||||
|     result.should.have.key("Endpoint").should.have.key("Address").match( |     assert re.match( | ||||||
|         "examplecluster.[a-z0-9]+.us-east-1.redshift.amazonaws.com" |         "examplecluster.[a-z0-9]+.us-east-1.redshift.amazonaws.com", | ||||||
|  |         result["Endpoint"]["Address"], | ||||||
|     ) |     ) | ||||||
|     result.should.have.key("Endpoint").should.have.key("Port").equal( |     assert result["Endpoint"]["Port"] == (5439 if is_json else "5439") | ||||||
|         5439 if is_json else "5439" |     assert "ClusterCreateTime" in result | ||||||
|     ) |  | ||||||
|     result.should.have.key("ClusterCreateTime") |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @pytest.mark.parametrize("is_json", [True, False], ids=["JSON", "XML"]) | @pytest.mark.parametrize("is_json", [True, False], ids=["JSON", "XML"]) | ||||||
| @ -122,37 +119,34 @@ def test_create_cluster_multiple_params(is_json): | |||||||
|         result = xmltodict.parse(result, dict_constructor=dict) |         result = xmltodict.parse(result, dict_constructor=dict) | ||||||
| 
 | 
 | ||||||
|     del result["CreateClusterResponse"]["ResponseMetadata"] |     del result["CreateClusterResponse"]["ResponseMetadata"] | ||||||
|     result.should.have.key("CreateClusterResponse") |     assert "CreateClusterResponse" in result | ||||||
|     result["CreateClusterResponse"].should.have.key("CreateClusterResult") |     assert "CreateClusterResult" in result["CreateClusterResponse"] | ||||||
|     result["CreateClusterResponse"]["CreateClusterResult"].should.have.key("Cluster") |     assert "Cluster" in result["CreateClusterResponse"]["CreateClusterResult"] | ||||||
|     result = result["CreateClusterResponse"]["CreateClusterResult"]["Cluster"] |     result = result["CreateClusterResponse"]["CreateClusterResult"]["Cluster"] | ||||||
| 
 | 
 | ||||||
|     result.should.have.key("MasterUsername").equal("masteruser") |     assert result["MasterUsername"] == "masteruser" | ||||||
|     result.should.have.key("MasterUserPassword").equal("****") |     assert result["MasterUserPassword"] == "****" | ||||||
|     result.should.have.key("ClusterVersion").equal("2.0") |     assert result["ClusterVersion"] == "2.0" | ||||||
|     result.should.have.key("ClusterSubnetGroupName").equal(None) |     assert result["ClusterSubnetGroupName"] is None | ||||||
|     result.should.have.key("AvailabilityZone").equal("us-east-1a") |     assert result["AvailabilityZone"] == "us-east-1a" | ||||||
|     result.should.have.key("ClusterStatus").equal("creating") |     assert result["ClusterStatus"] == "creating" | ||||||
|     result.should.have.key("NumberOfNodes").equal(3 if is_json else "3") |     assert result["NumberOfNodes"] == (3 if is_json else "3") | ||||||
|     result.should.have.key("PubliclyAccessible").equal(None) |     assert result["PubliclyAccessible"] is None | ||||||
|     result.should.have.key("Encrypted").equal("True") |     assert result["Encrypted"] == "True" | ||||||
|     result.should.have.key("DBName").equal("testdb") |     assert result["DBName"] == "testdb" | ||||||
|     result.should.have.key("NodeType").equal("ds2.xlarge") |     assert result["NodeType"] == "ds2.xlarge" | ||||||
|     result.should.have.key("ClusterIdentifier").equal("examplecluster") |     assert result["ClusterIdentifier"] == "examplecluster" | ||||||
|     result.should.have.key("Endpoint").should.have.key("Address").match( |     assert re.match( | ||||||
|         "examplecluster.[a-z0-9]+.us-east-1.redshift.amazonaws.com" |         "examplecluster.[a-z0-9]+.us-east-1.redshift.amazonaws.com", | ||||||
|  |         result["Endpoint"]["Address"], | ||||||
|     ) |     ) | ||||||
|     result.should.have.key("Endpoint").should.have.key("Port").equal( |     assert result["Endpoint"]["Port"] == (1234 if is_json else "1234") | ||||||
|         1234 if is_json else "1234" |     assert "ClusterCreateTime" in result | ||||||
|     ) |     assert "Tags" in result | ||||||
|     result.should.have.key("ClusterCreateTime") |  | ||||||
|     result.should.have.key("Tags") |  | ||||||
|     tags = result["Tags"] |     tags = result["Tags"] | ||||||
|     if not is_json: |     if not is_json: | ||||||
|         tags = tags["item"] |         tags = tags["item"] | ||||||
|     tags.should.equal( |     assert tags == [{"Key": "key1", "Value": "val1"}, {"Key": "key2", "Value": "val2"}] | ||||||
|         [{"Key": "key1", "Value": "val1"}, {"Key": "key2", "Value": "val2"}] |  | ||||||
|     ) |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @pytest.mark.parametrize("is_json", [True, False], ids=["JSON", "XML"]) | @pytest.mark.parametrize("is_json", [True, False], ids=["JSON", "XML"]) | ||||||
| @ -185,18 +179,16 @@ def test_create_and_describe_clusters(is_json): | |||||||
|         result = xmltodict.parse(result, dict_constructor=dict) |         result = xmltodict.parse(result, dict_constructor=dict) | ||||||
| 
 | 
 | ||||||
|     del result["DescribeClustersResponse"]["ResponseMetadata"] |     del result["DescribeClustersResponse"]["ResponseMetadata"] | ||||||
|     result.should.have.key("DescribeClustersResponse") |     assert "DescribeClustersResponse" in result | ||||||
|     result["DescribeClustersResponse"].should.have.key("DescribeClustersResult") |     assert "DescribeClustersResult" in result["DescribeClustersResponse"] | ||||||
|     result["DescribeClustersResponse"]["DescribeClustersResult"].should.have.key( |     assert "Clusters" in result["DescribeClustersResponse"]["DescribeClustersResult"] | ||||||
|         "Clusters" |  | ||||||
|     ) |  | ||||||
|     result = result["DescribeClustersResponse"]["DescribeClustersResult"]["Clusters"] |     result = result["DescribeClustersResponse"]["DescribeClustersResult"]["Clusters"] | ||||||
|     if not is_json: |     if not is_json: | ||||||
|         result = result["item"] |         result = result["item"] | ||||||
| 
 | 
 | ||||||
|     result.should.have.length_of(2) |     assert len(result) == 2 | ||||||
|     for cluster in result: |     for cluster in result: | ||||||
|         cluster_names.should.contain(cluster["ClusterIdentifier"]) |         assert cluster["ClusterIdentifier"] in cluster_names | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @pytest.mark.parametrize("is_json", [True, False], ids=["JSON", "XML"]) | @pytest.mark.parametrize("is_json", [True, False], ids=["JSON", "XML"]) | ||||||
| @ -233,12 +225,12 @@ def test_create_and_describe_cluster_security_group(is_json): | |||||||
|         groups = groups["item"] |         groups = groups["item"] | ||||||
| 
 | 
 | ||||||
|     descriptions = [g["Description"] for g in groups] |     descriptions = [g["Description"] for g in groups] | ||||||
|     descriptions.should.contain("desc for csg1") |     assert "desc for csg1" in descriptions | ||||||
|     descriptions.should.contain("desc for csg2") |     assert "desc for csg2" in descriptions | ||||||
| 
 | 
 | ||||||
|     # Describe single SG |     # Describe single SG | ||||||
|     describe_params = ( |     describe_params = ( | ||||||
|         "/?Action=DescribeClusterSecurityGroups" "&ClusterSecurityGroupName=csg1" |         "/?Action=DescribeClusterSecurityGroups&ClusterSecurityGroupName=csg1" | ||||||
|     ) |     ) | ||||||
|     if is_json: |     if is_json: | ||||||
|         describe_params += "&ContentType=JSON" |         describe_params += "&ContentType=JSON" | ||||||
| @ -255,10 +247,10 @@ def test_create_and_describe_cluster_security_group(is_json): | |||||||
|     ]["ClusterSecurityGroups"] |     ]["ClusterSecurityGroups"] | ||||||
| 
 | 
 | ||||||
|     if is_json: |     if is_json: | ||||||
|         groups.should.have.length_of(1) |         assert len(groups) == 1 | ||||||
|         groups[0]["ClusterSecurityGroupName"].should.equal("csg1") |         assert groups[0]["ClusterSecurityGroupName"] == "csg1" | ||||||
|     else: |     else: | ||||||
|         groups["item"]["ClusterSecurityGroupName"].should.equal("csg1") |         assert groups["item"]["ClusterSecurityGroupName"] == "csg1" | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @pytest.mark.parametrize("is_json", [True, False], ids=["JSON", "XML"]) | @pytest.mark.parametrize("is_json", [True, False], ids=["JSON", "XML"]) | ||||||
| @ -268,13 +260,13 @@ def test_describe_unknown_cluster_security_group(is_json): | |||||||
|     test_client = backend.test_client() |     test_client = backend.test_client() | ||||||
| 
 | 
 | ||||||
|     describe_params = ( |     describe_params = ( | ||||||
|         "/?Action=DescribeClusterSecurityGroups" "&ClusterSecurityGroupName=unknown" |         "/?Action=DescribeClusterSecurityGroups&ClusterSecurityGroupName=unknown" | ||||||
|     ) |     ) | ||||||
|     if is_json: |     if is_json: | ||||||
|         describe_params += "&ContentType=JSON" |         describe_params += "&ContentType=JSON" | ||||||
|     res = test_client.get(describe_params) |     res = test_client.get(describe_params) | ||||||
| 
 | 
 | ||||||
|     res.status_code.should.equal(400) |     assert res.status_code == 400 | ||||||
| 
 | 
 | ||||||
|     if is_json: |     if is_json: | ||||||
|         response = json.loads(res.data.decode("utf-8")) |         response = json.loads(res.data.decode("utf-8")) | ||||||
| @ -284,8 +276,8 @@ def test_describe_unknown_cluster_security_group(is_json): | |||||||
|         ] |         ] | ||||||
|     error = response["Error"] |     error = response["Error"] | ||||||
| 
 | 
 | ||||||
|     error["Code"].should.equal("ClusterSecurityGroupNotFound") |     assert error["Code"] == "ClusterSecurityGroupNotFound" | ||||||
|     error["Message"].should.equal("Security group unknown not found.") |     assert error["Message"] == "Security group unknown not found." | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @pytest.mark.parametrize("is_json", [True, False], ids=["JSON", "XML"]) | @pytest.mark.parametrize("is_json", [True, False], ids=["JSON", "XML"]) | ||||||
| @ -312,15 +304,15 @@ def test_create_cluster_with_security_group(is_json): | |||||||
|             response = xmltodict.parse(response, dict_constructor=dict) |             response = xmltodict.parse(response, dict_constructor=dict) | ||||||
| 
 | 
 | ||||||
|         del response["CreateClusterSecurityGroupResponse"]["ResponseMetadata"] |         del response["CreateClusterSecurityGroupResponse"]["ResponseMetadata"] | ||||||
|         response.should.have.key("CreateClusterSecurityGroupResponse") |         assert "CreateClusterSecurityGroupResponse" in response | ||||||
|         response = response["CreateClusterSecurityGroupResponse"] |         response = response["CreateClusterSecurityGroupResponse"] | ||||||
|         response.should.have.key("CreateClusterSecurityGroupResult") |         assert "CreateClusterSecurityGroupResult" in response | ||||||
|         result = response["CreateClusterSecurityGroupResult"] |         result = response["CreateClusterSecurityGroupResult"] | ||||||
|         result.should.have.key("ClusterSecurityGroup") |         assert "ClusterSecurityGroup" in result | ||||||
|         sg = result["ClusterSecurityGroup"] |         sg = result["ClusterSecurityGroup"] | ||||||
|         sg.should.have.key("ClusterSecurityGroupName").being.equal(csg) |         assert sg["ClusterSecurityGroupName"] == csg | ||||||
|         sg.should.have.key("Description").being.equal("desc for " + csg) |         assert sg["Description"] == "desc for " + csg | ||||||
|         sg.should.have.key("EC2SecurityGroups").being.equal([] if is_json else None) |         assert sg["EC2SecurityGroups"] == ([] if is_json else None) | ||||||
| 
 | 
 | ||||||
|     # Create Cluster with these security groups |     # Create Cluster with these security groups | ||||||
|     create_params = ( |     create_params = ( | ||||||
| @ -344,17 +336,15 @@ def test_create_cluster_with_security_group(is_json): | |||||||
|         result = xmltodict.parse(result, dict_constructor=dict) |         result = xmltodict.parse(result, dict_constructor=dict) | ||||||
| 
 | 
 | ||||||
|     del result["CreateClusterResponse"]["ResponseMetadata"] |     del result["CreateClusterResponse"]["ResponseMetadata"] | ||||||
|     result.should.have.key("CreateClusterResponse") |     assert "CreateClusterResponse" in result | ||||||
|     result["CreateClusterResponse"].should.have.key("CreateClusterResult") |     assert "CreateClusterResult" in result["CreateClusterResponse"] | ||||||
|     result["CreateClusterResponse"]["CreateClusterResult"].should.have.key("Cluster") |     assert "Cluster" in result["CreateClusterResponse"]["CreateClusterResult"] | ||||||
|     result = result["CreateClusterResponse"]["CreateClusterResult"]["Cluster"] |     result = result["CreateClusterResponse"]["CreateClusterResult"]["Cluster"] | ||||||
| 
 | 
 | ||||||
|     security_groups = result["ClusterSecurityGroups"] |     security_groups = result["ClusterSecurityGroups"] | ||||||
|     if not is_json: |     if not is_json: | ||||||
|         security_groups = security_groups["item"] |         security_groups = security_groups["item"] | ||||||
| 
 | 
 | ||||||
|     security_groups.should.have.length_of(2) |     assert len(security_groups) == 2 | ||||||
|     for csg in security_group_names: |     for csg in security_group_names: | ||||||
|         security_groups.should.contain( |         assert {"ClusterSecurityGroupName": csg, "Status": "active"} in security_groups | ||||||
|             {"ClusterSecurityGroupName": csg, "Status": "active"} |  | ||||||
|         ) |  | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user