Techdebt: Replace sure with regular assertions in RedShift (#6690)

This commit is contained in:
kbalk 2023-08-19 13:58:06 -04:00 committed by GitHub
parent f57bb251c8
commit ca83236da6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 591 additions and 555 deletions

View File

@ -1,14 +1,15 @@
import string
import re
import responses
import requests
import time
from datetime import datetime
from collections import defaultdict
from openapi_spec_validator import validate_spec
from datetime import datetime
import re
import string
import time
from typing import Any, Dict, List, Optional, Tuple, Union
from urllib.parse import urlparse
from openapi_spec_validator import validate_spec
import requests
import responses
try:
from openapi_spec_validator.validation.exceptions import OpenAPIValidationError
except ImportError:
@ -1511,7 +1512,7 @@ class APIGatewayBackend(BaseBackend):
integrationHttpMethod="GET"
)
deploy_url = f"https://{api_id}.execute-api.us-east-1.amazonaws.com/dev"
requests.get(deploy_url).content.should.equal(b"a fake response")
assert requests.get(deploy_url).content == b"a fake response"
Limitations:
- Integrations of type HTTP are supported

View File

@ -1,12 +1,13 @@
import time
from datetime import datetime
import time
from typing import Any, Dict, List, Optional
from moto.core import BaseBackend, BackendDict, BaseModel
from moto.moto_api._internal import mock_random
from moto.utilities.paginator import paginate
from typing import Any, Dict, List, Optional
class TaggableResourceMixin(object):
class TaggableResourceMixin:
# This mixing was copied from Redshift when initially implementing
# Athena. TBD if it's worth the overhead.
@ -263,7 +264,7 @@ class AthenaBackend(BaseBackend):
"http://motoapi.amazonaws.com:5000/moto-api/static/athena/query-results",
json=expected_results,
)
resp.status_code.should.equal(201)
assert resp.status_code == 201
client = boto3.client("athena", region_name="us-east-1")
details = client.get_query_execution(QueryExecutionId="any_id")["QueryExecution"]

View File

@ -1,7 +1,7 @@
from moto.core import BaseBackend, BackendDict
from typing import Any, Dict, List, Optional, Tuple
from moto.core import BaseBackend, BackendDict
class QueryResults:
def __init__(
@ -66,7 +66,7 @@ class RDSDataServiceBackend(BaseBackend):
"http://motoapi.amazonaws.com:5000/moto-api/static/rds-data/statement-results",
json=expected_results,
)
resp.status_code.should.equal(201)
assert resp.status_code == 201
rdsdata = boto3.client("rds-data", region_name="us-east-1")
resp = rdsdata.execute_statement(resourceArn="not applicable", secretArn="not applicable", sql="SELECT some FROM thing")

View File

@ -1,8 +1,10 @@
from collections import OrderedDict
import copy
import datetime
from collections import OrderedDict
from typing import Any, Dict, Iterable, List, Optional
from dateutil.tz import tzutc
from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel
from moto.core.utils import iso_8601_datetime_with_milliseconds
from moto.ec2 import ec2_backends
@ -46,7 +48,10 @@ class TaggableResourceMixin:
@property
def arn(self) -> str:
return f"arn:aws:redshift:{self.region}:{self.account_id}:{self.resource_type}:{self.resource_id}"
return (
f"arn:aws:redshift:{self.region}:{self.account_id}"
f":{self.resource_type}:{self.resource_id}"
)
def create_tags(self, tags: List[Dict[str, str]]) -> List[Dict[str, str]]:
new_keys = [tag_set["Key"] for tag_set in tags]
@ -95,7 +100,7 @@ class Cluster(TaggableResourceMixin, CloudFormationModel):
self.redshift_backend = redshift_backend
self.cluster_identifier = cluster_identifier
self.create_time = iso_8601_datetime_with_milliseconds(
datetime.datetime.utcnow()
datetime.datetime.now(tzutc())
)
self.status = "available"
self.node_type = node_type
@ -220,7 +225,7 @@ class Cluster(TaggableResourceMixin, CloudFormationModel):
if attribute_name == "Endpoint.Address":
return self.endpoint
elif attribute_name == "Endpoint.Port":
if attribute_name == "Endpoint.Port":
return self.port
raise UnformattedGetAttTemplateException()
@ -528,7 +533,7 @@ class Snapshot(TaggableResourceMixin, BaseModel):
self.snapshot_type = snapshot_type
self.status = "available"
self.create_time = iso_8601_datetime_with_milliseconds(
datetime.datetime.utcnow()
datetime.datetime.now(tzutc())
)
self.iam_roles_arn = iam_roles_arn or []
@ -622,7 +627,6 @@ class RedshiftBackend(BaseBackend):
}
cluster.cluster_snapshot_copy_status = status
return cluster
else:
raise SnapshotCopyAlreadyEnabledFaultError(cluster_identifier)
def disable_snapshot_copy(self, **kwargs: Any) -> Cluster:
@ -631,7 +635,6 @@ class RedshiftBackend(BaseBackend):
if cluster.cluster_snapshot_copy_status is not None:
cluster.cluster_snapshot_copy_status = None
return cluster
else:
raise SnapshotCopyAlreadyDisabledFaultError(cluster_identifier)
def modify_snapshot_copy_retention_period(
@ -650,7 +653,10 @@ class RedshiftBackend(BaseBackend):
raise ClusterAlreadyExistsFaultError()
cluster = Cluster(self, **cluster_kwargs)
self.clusters[cluster_identifier] = cluster
snapshot_id = f"rs:{cluster_identifier}-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}"
snapshot_id = (
f"rs:{cluster_identifier}-"
f"{datetime.datetime.now(tzutc()).strftime('%Y-%m-%d-%H-%M')}"
)
# Automated snapshots don't copy over the tags
self.create_cluster_snapshot(
cluster_identifier,
@ -679,7 +685,6 @@ class RedshiftBackend(BaseBackend):
if cluster_identifier:
if cluster_identifier in self.clusters:
return [self.clusters[cluster_identifier]]
else:
raise ClusterNotFoundError(cluster_identifier)
return list(self.clusters.values())
@ -739,9 +744,10 @@ class RedshiftBackend(BaseBackend):
and cluster_snapshot_identifer is None
):
raise InvalidParameterCombinationError(
"FinalClusterSnapshotIdentifier is required unless SkipFinalClusterSnapshot is specified."
"FinalClusterSnapshotIdentifier is required unless "
"SkipFinalClusterSnapshot is specified."
)
elif (
if (
cluster_skip_final_snapshot is False
and cluster_snapshot_identifer is not None
): # create snapshot
@ -781,7 +787,6 @@ class RedshiftBackend(BaseBackend):
if subnet_identifier:
if subnet_identifier in self.subnet_groups:
return [self.subnet_groups[subnet_identifier]]
else:
raise ClusterSubnetGroupNotFoundError(subnet_identifier)
return list(self.subnet_groups.values())
@ -812,7 +817,6 @@ class RedshiftBackend(BaseBackend):
if security_group_name:
if security_group_name in self.security_groups:
return [self.security_groups[security_group_name]]
else:
raise ClusterSecurityGroupNotFoundError(security_group_name)
return list(self.security_groups.values())
@ -861,7 +865,6 @@ class RedshiftBackend(BaseBackend):
if parameter_group_name:
if parameter_group_name in self.parameter_groups:
return [self.parameter_groups[parameter_group_name]]
else:
raise ClusterParameterGroupNotFoundError(parameter_group_name)
return list(self.parameter_groups.values())
@ -983,7 +986,6 @@ class RedshiftBackend(BaseBackend):
if snapshot_copy_grant_name:
if snapshot_copy_grant_name in self.snapshot_copy_grants:
return [self.snapshot_copy_grants[snapshot_copy_grant_name]]
else:
raise SnapshotCopyGrantNotFoundFaultError(snapshot_copy_grant_name)
return copy_grants
@ -1000,15 +1002,15 @@ class RedshiftBackend(BaseBackend):
resources = self.RESOURCE_TYPE_MAP.get(resource_type)
if resources is None:
message = (
f"Tagging is not supported for this type of resource: '{resource_type}' "
"(the ARN is potentially malformed, please check the ARN documentation for more information)"
"Tagging is not supported for this type of resource: "
f"'{resource_type}' (the ARN is potentially malformed, "
"please check the ARN documentation for more information)"
)
raise ResourceNotFoundFaultError(message=message)
try:
resource = resources[resource_id]
except KeyError:
raise ResourceNotFoundFaultError(resource_type, resource_id)
else:
return resource
@staticmethod
@ -1081,7 +1083,7 @@ class RedshiftBackend(BaseBackend):
raise InvalidParameterValueError(
"Token duration must be between 900 and 3600 seconds"
)
expiration = datetime.datetime.utcnow() + datetime.timedelta(
expiration = datetime.datetime.now(tzutc()) + datetime.timedelta(
0, duration_seconds
)
if cluster_identifier in self.clusters:
@ -1092,7 +1094,6 @@ class RedshiftBackend(BaseBackend):
"DbPassword": mock_random.get_random_string(32),
"Expiration": expiration,
}
else:
raise ClusterNotFoundError(cluster_identifier)

View File

@ -3,6 +3,5 @@ pytest
pytest-cov
pytest-ordering
pytest-xdist
surer
freezegun
pylint

View File

@ -1,15 +1,17 @@
from unittest import SkipTest
import boto3
import pytest
from moto import mock_s3
from moto import settings
from unittest import SkipTest
@pytest.fixture(scope="function", name="aws_credentials")
def fixture_aws_credentials(monkeypatch):
"""Mocked AWS Credentials for moto."""
if settings.TEST_SERVER_MODE:
raise SkipTest("No point in testing this in ServerMode.")
"""Mocked AWS Credentials for moto."""
monkeypatch.setenv("AWS_ACCESS_KEY_ID", "testing")
monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "testing")
monkeypatch.setenv("AWS_SECURITY_TOKEN", "testing")
@ -63,8 +65,7 @@ def test_mock_works_with_resource_created_outside(
patch_resource(outside_resource)
b = list(outside_resource.buckets.all())
assert b == []
assert not list(outside_resource.buckets.all())
m.stop()
@ -126,7 +127,7 @@ def test_mock_works_when_replacing_client(
try:
logic.do_important_things()
except Exception as e:
str(e).should.contain("InvalidAccessKeyId")
assert str(e) == "InvalidAccessKeyId"
client_initialized_after_mock = boto3.client("s3", region_name="us-east-1")
logic._s3 = client_initialized_after_mock

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,6 @@
import boto3
import json
import sure # noqa # pylint: disable=unused-import
import boto3
from moto import mock_cloudformation, mock_ec2, mock_redshift
from tests.test_cloudformation.fixtures import redshift
@ -33,20 +33,20 @@ def test_redshift_stack():
cluster_res = redshift_conn.describe_clusters()
clusters = cluster_res["Clusters"]
clusters.should.have.length_of(1)
assert len(clusters) == 1
cluster = clusters[0]
cluster["DBName"].should.equal("mydb")
cluster["NumberOfNodes"].should.equal(2)
cluster["NodeType"].should.equal("dw1.xlarge")
cluster["MasterUsername"].should.equal("myuser")
cluster["Endpoint"]["Port"].should.equal(5439)
cluster["VpcSecurityGroups"].should.have.length_of(1)
assert cluster["DBName"] == "mydb"
assert cluster["NumberOfNodes"] == 2
assert cluster["NodeType"] == "dw1.xlarge"
assert cluster["MasterUsername"] == "myuser"
assert cluster["Endpoint"]["Port"] == 5439
assert len(cluster["VpcSecurityGroups"]) == 1
security_group_id = cluster["VpcSecurityGroups"][0]["VpcSecurityGroupId"]
groups = ec2.describe_security_groups(GroupIds=[security_group_id])[
"SecurityGroups"
]
groups.should.have.length_of(1)
assert len(groups) == 1
group = groups[0]
group["IpPermissions"].should.have.length_of(1)
group["IpPermissions"][0]["IpRanges"][0]["CidrIp"].should.equal("10.0.0.1/16")
assert len(group["IpPermissions"]) == 1
assert group["IpPermissions"][0]["IpRanges"][0]["CidrIp"] == "10.0.0.1/16"

View File

@ -1,15 +1,13 @@
import sure # noqa # pylint: disable=unused-import
"""Test the different server responses."""
import json
import re
import pytest
import xmltodict
import moto.server as server
from moto import mock_redshift
"""
Test the different server responses
"""
@mock_redshift
def test_describe_clusters():
@ -19,7 +17,7 @@ def test_describe_clusters():
res = test_client.get("/?Action=DescribeClusters")
result = res.data.decode("utf-8")
result.should.contain("<Clusters></Clusters>")
assert "<Clusters></Clusters>" in result
@mock_redshift
@ -31,9 +29,9 @@ def test_describe_clusters_with_json_content_type():
result = json.loads(res.data.decode("utf-8"))
del result["DescribeClustersResponse"]["ResponseMetadata"]
result.should.equal(
{"DescribeClustersResponse": {"DescribeClustersResult": {"Clusters": []}}}
)
assert result == {
"DescribeClustersResponse": {"DescribeClustersResult": {"Clusters": []}}
}
@pytest.mark.parametrize("is_json", [True, False], ids=["JSON", "XML"])
@ -61,30 +59,29 @@ def test_create_cluster(is_json):
result = xmltodict.parse(result, dict_constructor=dict)
del result["CreateClusterResponse"]["ResponseMetadata"]
result.should.have.key("CreateClusterResponse")
result["CreateClusterResponse"].should.have.key("CreateClusterResult")
result["CreateClusterResponse"]["CreateClusterResult"].should.have.key("Cluster")
assert "CreateClusterResponse" in result
assert "CreateClusterResult" in result["CreateClusterResponse"]
assert "Cluster" in result["CreateClusterResponse"]["CreateClusterResult"]
result = result["CreateClusterResponse"]["CreateClusterResult"]["Cluster"]
result.should.have.key("MasterUsername").equal("masteruser")
result.should.have.key("MasterUserPassword").equal("****")
result.should.have.key("ClusterVersion").equal("1.0")
result.should.have.key("ClusterSubnetGroupName").equal(None)
result.should.have.key("AvailabilityZone").equal("us-east-1a")
result.should.have.key("ClusterStatus").equal("creating")
result.should.have.key("NumberOfNodes").equal(1 if is_json else "1")
result.should.have.key("PubliclyAccessible").equal(None)
result.should.have.key("Encrypted").equal(None)
result.should.have.key("DBName").equal("dev")
result.should.have.key("NodeType").equal("ds2.xlarge")
result.should.have.key("ClusterIdentifier").equal("examplecluster")
result.should.have.key("Endpoint").should.have.key("Address").match(
"examplecluster.[a-z0-9]+.us-east-1.redshift.amazonaws.com"
assert result["MasterUsername"] == "masteruser"
assert result["MasterUserPassword"] == "****"
assert result["ClusterVersion"] == "1.0"
assert result["ClusterSubnetGroupName"] is None
assert result["AvailabilityZone"] == "us-east-1a"
assert result["ClusterStatus"] == "creating"
assert result["NumberOfNodes"] == (1 if is_json else "1")
assert result["PubliclyAccessible"] is None
assert result["Encrypted"] is None
assert result["DBName"] == "dev"
assert result["NodeType"] == "ds2.xlarge"
assert result["ClusterIdentifier"] == "examplecluster"
assert re.match(
"examplecluster.[a-z0-9]+.us-east-1.redshift.amazonaws.com",
result["Endpoint"]["Address"],
)
result.should.have.key("Endpoint").should.have.key("Port").equal(
5439 if is_json else "5439"
)
result.should.have.key("ClusterCreateTime")
assert result["Endpoint"]["Port"] == (5439 if is_json else "5439")
assert "ClusterCreateTime" in result
@pytest.mark.parametrize("is_json", [True, False], ids=["JSON", "XML"])
@ -122,37 +119,34 @@ def test_create_cluster_multiple_params(is_json):
result = xmltodict.parse(result, dict_constructor=dict)
del result["CreateClusterResponse"]["ResponseMetadata"]
result.should.have.key("CreateClusterResponse")
result["CreateClusterResponse"].should.have.key("CreateClusterResult")
result["CreateClusterResponse"]["CreateClusterResult"].should.have.key("Cluster")
assert "CreateClusterResponse" in result
assert "CreateClusterResult" in result["CreateClusterResponse"]
assert "Cluster" in result["CreateClusterResponse"]["CreateClusterResult"]
result = result["CreateClusterResponse"]["CreateClusterResult"]["Cluster"]
result.should.have.key("MasterUsername").equal("masteruser")
result.should.have.key("MasterUserPassword").equal("****")
result.should.have.key("ClusterVersion").equal("2.0")
result.should.have.key("ClusterSubnetGroupName").equal(None)
result.should.have.key("AvailabilityZone").equal("us-east-1a")
result.should.have.key("ClusterStatus").equal("creating")
result.should.have.key("NumberOfNodes").equal(3 if is_json else "3")
result.should.have.key("PubliclyAccessible").equal(None)
result.should.have.key("Encrypted").equal("True")
result.should.have.key("DBName").equal("testdb")
result.should.have.key("NodeType").equal("ds2.xlarge")
result.should.have.key("ClusterIdentifier").equal("examplecluster")
result.should.have.key("Endpoint").should.have.key("Address").match(
"examplecluster.[a-z0-9]+.us-east-1.redshift.amazonaws.com"
assert result["MasterUsername"] == "masteruser"
assert result["MasterUserPassword"] == "****"
assert result["ClusterVersion"] == "2.0"
assert result["ClusterSubnetGroupName"] is None
assert result["AvailabilityZone"] == "us-east-1a"
assert result["ClusterStatus"] == "creating"
assert result["NumberOfNodes"] == (3 if is_json else "3")
assert result["PubliclyAccessible"] is None
assert result["Encrypted"] == "True"
assert result["DBName"] == "testdb"
assert result["NodeType"] == "ds2.xlarge"
assert result["ClusterIdentifier"] == "examplecluster"
assert re.match(
"examplecluster.[a-z0-9]+.us-east-1.redshift.amazonaws.com",
result["Endpoint"]["Address"],
)
result.should.have.key("Endpoint").should.have.key("Port").equal(
1234 if is_json else "1234"
)
result.should.have.key("ClusterCreateTime")
result.should.have.key("Tags")
assert result["Endpoint"]["Port"] == (1234 if is_json else "1234")
assert "ClusterCreateTime" in result
assert "Tags" in result
tags = result["Tags"]
if not is_json:
tags = tags["item"]
tags.should.equal(
[{"Key": "key1", "Value": "val1"}, {"Key": "key2", "Value": "val2"}]
)
assert tags == [{"Key": "key1", "Value": "val1"}, {"Key": "key2", "Value": "val2"}]
@pytest.mark.parametrize("is_json", [True, False], ids=["JSON", "XML"])
@ -185,18 +179,16 @@ def test_create_and_describe_clusters(is_json):
result = xmltodict.parse(result, dict_constructor=dict)
del result["DescribeClustersResponse"]["ResponseMetadata"]
result.should.have.key("DescribeClustersResponse")
result["DescribeClustersResponse"].should.have.key("DescribeClustersResult")
result["DescribeClustersResponse"]["DescribeClustersResult"].should.have.key(
"Clusters"
)
assert "DescribeClustersResponse" in result
assert "DescribeClustersResult" in result["DescribeClustersResponse"]
assert "Clusters" in result["DescribeClustersResponse"]["DescribeClustersResult"]
result = result["DescribeClustersResponse"]["DescribeClustersResult"]["Clusters"]
if not is_json:
result = result["item"]
result.should.have.length_of(2)
assert len(result) == 2
for cluster in result:
cluster_names.should.contain(cluster["ClusterIdentifier"])
assert cluster["ClusterIdentifier"] in cluster_names
@pytest.mark.parametrize("is_json", [True, False], ids=["JSON", "XML"])
@ -233,12 +225,12 @@ def test_create_and_describe_cluster_security_group(is_json):
groups = groups["item"]
descriptions = [g["Description"] for g in groups]
descriptions.should.contain("desc for csg1")
descriptions.should.contain("desc for csg2")
assert "desc for csg1" in descriptions
assert "desc for csg2" in descriptions
# Describe single SG
describe_params = (
"/?Action=DescribeClusterSecurityGroups" "&ClusterSecurityGroupName=csg1"
"/?Action=DescribeClusterSecurityGroups&ClusterSecurityGroupName=csg1"
)
if is_json:
describe_params += "&ContentType=JSON"
@ -255,10 +247,10 @@ def test_create_and_describe_cluster_security_group(is_json):
]["ClusterSecurityGroups"]
if is_json:
groups.should.have.length_of(1)
groups[0]["ClusterSecurityGroupName"].should.equal("csg1")
assert len(groups) == 1
assert groups[0]["ClusterSecurityGroupName"] == "csg1"
else:
groups["item"]["ClusterSecurityGroupName"].should.equal("csg1")
assert groups["item"]["ClusterSecurityGroupName"] == "csg1"
@pytest.mark.parametrize("is_json", [True, False], ids=["JSON", "XML"])
@ -268,13 +260,13 @@ def test_describe_unknown_cluster_security_group(is_json):
test_client = backend.test_client()
describe_params = (
"/?Action=DescribeClusterSecurityGroups" "&ClusterSecurityGroupName=unknown"
"/?Action=DescribeClusterSecurityGroups&ClusterSecurityGroupName=unknown"
)
if is_json:
describe_params += "&ContentType=JSON"
res = test_client.get(describe_params)
res.status_code.should.equal(400)
assert res.status_code == 400
if is_json:
response = json.loads(res.data.decode("utf-8"))
@ -284,8 +276,8 @@ def test_describe_unknown_cluster_security_group(is_json):
]
error = response["Error"]
error["Code"].should.equal("ClusterSecurityGroupNotFound")
error["Message"].should.equal("Security group unknown not found.")
assert error["Code"] == "ClusterSecurityGroupNotFound"
assert error["Message"] == "Security group unknown not found."
@pytest.mark.parametrize("is_json", [True, False], ids=["JSON", "XML"])
@ -312,15 +304,15 @@ def test_create_cluster_with_security_group(is_json):
response = xmltodict.parse(response, dict_constructor=dict)
del response["CreateClusterSecurityGroupResponse"]["ResponseMetadata"]
response.should.have.key("CreateClusterSecurityGroupResponse")
assert "CreateClusterSecurityGroupResponse" in response
response = response["CreateClusterSecurityGroupResponse"]
response.should.have.key("CreateClusterSecurityGroupResult")
assert "CreateClusterSecurityGroupResult" in response
result = response["CreateClusterSecurityGroupResult"]
result.should.have.key("ClusterSecurityGroup")
assert "ClusterSecurityGroup" in result
sg = result["ClusterSecurityGroup"]
sg.should.have.key("ClusterSecurityGroupName").being.equal(csg)
sg.should.have.key("Description").being.equal("desc for " + csg)
sg.should.have.key("EC2SecurityGroups").being.equal([] if is_json else None)
assert sg["ClusterSecurityGroupName"] == csg
assert sg["Description"] == "desc for " + csg
assert sg["EC2SecurityGroups"] == ([] if is_json else None)
# Create Cluster with these security groups
create_params = (
@ -344,17 +336,15 @@ def test_create_cluster_with_security_group(is_json):
result = xmltodict.parse(result, dict_constructor=dict)
del result["CreateClusterResponse"]["ResponseMetadata"]
result.should.have.key("CreateClusterResponse")
result["CreateClusterResponse"].should.have.key("CreateClusterResult")
result["CreateClusterResponse"]["CreateClusterResult"].should.have.key("Cluster")
assert "CreateClusterResponse" in result
assert "CreateClusterResult" in result["CreateClusterResponse"]
assert "Cluster" in result["CreateClusterResponse"]["CreateClusterResult"]
result = result["CreateClusterResponse"]["CreateClusterResult"]["Cluster"]
security_groups = result["ClusterSecurityGroups"]
if not is_json:
security_groups = security_groups["item"]
security_groups.should.have.length_of(2)
assert len(security_groups) == 2
for csg in security_group_names:
security_groups.should.contain(
{"ClusterSecurityGroupName": csg, "Status": "active"}
)
assert {"ClusterSecurityGroupName": csg, "Status": "active"} in security_groups