From 2ec1a0477888d5dcc4f2cde95bc6c9502f7762d6 Mon Sep 17 00:00:00 2001 From: vincbeck <97131062+vincbeck@users.noreply.github.com> Date: Fri, 11 Feb 2022 15:11:22 -0500 Subject: [PATCH] Add redshift-data: cancel_statement, describe_statement, execute_statement, get_statement_result (#4832) --- IMPLEMENTATION_COVERAGE.md | 19 ++ docs/docs/services/redshift-data.rst | 40 +++ moto/__init__.py | 3 + moto/backend_index.py | 1 + moto/redshiftdata/__init__.py | 4 + moto/redshiftdata/exceptions.py | 11 + moto/redshiftdata/models.py | 217 ++++++++++++++++ moto/redshiftdata/responses.py | 60 +++++ moto/redshiftdata/urls.py | 10 + tests/helpers.py | 11 + tests/test_redshiftdata/__init__.py | 0 tests/test_redshiftdata/test_redshiftdata.py | 182 ++++++++++++++ .../test_redshiftdata_constants.py | 11 + tests/test_redshiftdata/test_server.py | 233 ++++++++++++++++++ 14 files changed, 802 insertions(+) create mode 100644 docs/docs/services/redshift-data.rst create mode 100644 moto/redshiftdata/__init__.py create mode 100644 moto/redshiftdata/exceptions.py create mode 100644 moto/redshiftdata/models.py create mode 100644 moto/redshiftdata/responses.py create mode 100644 moto/redshiftdata/urls.py create mode 100644 tests/test_redshiftdata/__init__.py create mode 100644 tests/test_redshiftdata/test_redshiftdata.py create mode 100644 tests/test_redshiftdata/test_redshiftdata_constants.py create mode 100644 tests/test_redshiftdata/test_server.py diff --git a/IMPLEMENTATION_COVERAGE.md b/IMPLEMENTATION_COVERAGE.md index 47e821c5c..946ad7970 100644 --- a/IMPLEMENTATION_COVERAGE.md +++ b/IMPLEMENTATION_COVERAGE.md @@ -4150,6 +4150,25 @@ - [ ] update_partner_status +## redshift-data +
+30% implemented + +- [ ] batch_execute_statement +- [ ] can_paginate +- [X] cancel_statement +- [X] describe_statement +- [ ] describe_table +- [X] execute_statement +- [ ] get_paginator +- [X] get_statement_result +- [ ] get_waiter +- [ ] list_databases +- [ ] list_schemas +- [ ] list_statements +- [ ] list_tables +
+ ## resource-groups
68% implemented diff --git a/docs/docs/services/redshift-data.rst b/docs/docs/services/redshift-data.rst new file mode 100644 index 000000000..f998549c0 --- /dev/null +++ b/docs/docs/services/redshift-data.rst @@ -0,0 +1,40 @@ +.. _implementedservice_redshift-data: + +.. |start-h3| raw:: html + +

+ +.. |end-h3| raw:: html + +

+ +======== +redshift-data +======== + +|start-h3| Example usage |end-h3| + +.. sourcecode:: python + + @mock_redshiftdata + def test_redshift_behaviour: + boto3.client("redshift-data") + ... + + + +|start-h3| Implemented features for this service |end-h3| + +- [ ] batch_execute_statement +- [ ] can_paginate +- [X] cancel_statement +- [X] describe_statement +- [ ] describe_table +- [X] execute_statement +- [ ] get_paginator +- [X] get_statement_result +- [ ] get_waiter +- [ ] list_databases +- [ ] list_schemas +- [ ] list_statements +- [ ] list_tables diff --git a/moto/__init__.py b/moto/__init__.py index 9ad1baeac..3178dc64c 100644 --- a/moto/__init__.py +++ b/moto/__init__.py @@ -111,6 +111,9 @@ mock_ram = lazy_load(".ram", "mock_ram") mock_rds = lazy_load(".rds", "mock_rds", warn_repurpose=True) mock_rds2 = lazy_load(".rds2", "mock_rds2", boto3_name="rds") mock_redshift = lazy_load(".redshift", "mock_redshift") +mock_redshiftdata = lazy_load( + ".redshiftdata", "mock_redshiftdata", boto3_name="redshift-data" +) mock_resourcegroups = lazy_load( ".resourcegroups", "mock_resourcegroups", boto3_name="resource-groups" ) diff --git a/moto/backend_index.py b/moto/backend_index.py index b14525e38..f58ad44d5 100644 --- a/moto/backend_index.py +++ b/moto/backend_index.py @@ -99,6 +99,7 @@ backend_url_patterns = [ ("rds", re.compile("https?://rds\\.(.+)\\.amazonaws\\.com")), ("rds", re.compile("https?://rds\\.amazonaws\\.com")), ("redshift", re.compile("https?://redshift\\.(.+)\\.amazonaws\\.com")), + ("redshift-data", re.compile("https?://redshift-data\\.(.+)\\.amazonaws\\.com")), ( "resource-groups", re.compile("https?://resource-groups(-fips)?\\.(.+)\\.amazonaws.com"), diff --git a/moto/redshiftdata/__init__.py b/moto/redshiftdata/__init__.py new file mode 100644 index 000000000..57a653521 --- /dev/null +++ b/moto/redshiftdata/__init__.py @@ -0,0 +1,4 @@ +from .models import redshiftdata_backends +from ..core.models import base_decorator + +mock_redshiftdata = base_decorator(redshiftdata_backends) diff --git a/moto/redshiftdata/exceptions.py b/moto/redshiftdata/exceptions.py new file mode 100644 index 000000000..e3b721223 --- /dev/null +++ b/moto/redshiftdata/exceptions.py @@ -0,0 +1,11 @@ +from moto.core.exceptions import JsonRESTError + + +class ResourceNotFoundException(JsonRESTError): + def __init__(self): + super().__init__("ResourceNotFoundException", "Query does not exist.") + + +class ValidationException(JsonRESTError): + def __init__(self, message): + super().__init__("ValidationException", message) diff --git a/moto/redshiftdata/models.py b/moto/redshiftdata/models.py new file mode 100644 index 000000000..b81ac6bd4 --- /dev/null +++ b/moto/redshiftdata/models.py @@ -0,0 +1,217 @@ +import re +import uuid +from datetime import datetime +import random + +from moto.core import BaseBackend +from moto.core.utils import BackendDict, iso_8601_datetime_without_milliseconds +from moto.redshiftdata.exceptions import ValidationException, ResourceNotFoundException + + +class Statement: + def __init__( + self, + cluster_identifier, + database, + db_user, + query_parameters, + query_string, + secret_arn, + ): + now = iso_8601_datetime_without_milliseconds(datetime.now()) + + self.id = str(uuid.uuid4()) + self.cluster_identifier = cluster_identifier + self.created_at = now + self.database = database + self.db_user = db_user + self.duration = 0 + self.has_result_set = False + self.query_parameters = query_parameters + self.query_string = query_string + self.redshift_pid = random.randint(0, 99999) + self.redshift_query_id = random.randint(0, 99999) + self.result_rows = -1 + self.result_size = -1 + self.secret_arn = secret_arn + self.status = "STARTED" + self.sub_statements = [] + self.updated_at = now + + def __iter__(self): + yield "Id", self.id + yield "ClusterIdentifier", self.cluster_identifier + yield "CreatedAt", self.created_at + yield "Database", self.database + yield "DbUser", self.db_user + yield "Duration", self.duration + yield "HasResultSet", self.has_result_set + yield "QueryParameters", self.query_parameters + yield "QueryString", self.query_string + yield "RedshiftPid", self.redshift_pid + yield "RedshiftQueryId", self.redshift_query_id + yield "ResultRows", self.result_rows + yield "ResultSize", self.result_size + yield "SecretArn", self.secret_arn + yield "Status", self.status + yield "SubStatements", self.sub_statements + yield "UpdatedAt", self.updated_at + + +class StatementResult: + def __init__( + self, column_metadata, records, total_number_rows, next_token=None, + ): + self.column_metadata = column_metadata + self.records = records + self.total_number_rows = total_number_rows + self.next_token = next_token + + def __iter__(self): + yield "ColumnMetadata", self.column_metadata + yield "Records", self.records + yield "TotalNumberRows", self.total_number_rows + yield "NextToken", self.next_token + + +class ColumnMetadata: + def __init__(self, column_default, is_case_sensitive, is_signed, name, nullable): + self.column_default = column_default + self.is_case_sensitive = is_case_sensitive + self.is_signed = is_signed + self.name = name + self.nullable = nullable + + def __iter__(self): + yield "columnDefault", self.column_default + yield "isCaseSensitive", self.is_case_sensitive + yield "isSigned", self.is_signed + yield "name", self.name + yield "nullable", self.nullable + + +class Record: + def __init__( + self, **kwargs, + ): + self.kwargs = kwargs + + def __iter__(self): + if "long_value" in self.kwargs: + yield "longValue", self.kwargs["long_value"] + elif "string_value" in self.kwargs: + yield "stringValue", self.kwargs["string_value"] + + +class RedshiftDataAPIServiceBackend(BaseBackend): + def __init__(self, region_name=None): + self.region_name = region_name + self.statements = {} + + def reset(self): + region_name = self.region_name + self.__dict__ = {} + self.__init__(region_name) + + def cancel_statement(self, statement_id): + _validate_uuid(statement_id) + + try: + # Statement exists + statement = self.statements[statement_id] + + if statement.status != "STARTED": + raise ValidationException( + "Could not cancel a query that is already in %s state with ID: %s" + % (statement.status, statement_id) + ) + + statement.status = "ABORTED" + self.statements[statement_id] = statement + except KeyError: + # Statement does not exist. + raise ResourceNotFoundException() + + return True + + def describe_statement(self, statement_id): + _validate_uuid(statement_id) + + try: + # Statement exists + return self.statements[statement_id] + except KeyError: + # Statement does not exist. + raise ResourceNotFoundException() + + def execute_statement( + self, cluster_identifier, database, db_user, parameters, secret_arn, sql, + ): + """ + Runs an SQL statement + Validation of parameters is very limited because there is no redshift integration + """ + statement = Statement( + cluster_identifier=cluster_identifier, + database=database, + db_user=db_user, + query_parameters=parameters, + query_string=sql, + secret_arn=secret_arn, + ) + self.statements[statement.id] = statement + return statement + + def get_statement_result(self, statement_id): + """ + Return static statement result + StatementResult is the result of the SQL query "sql" passed as parameter when calling "execute_statement" + As such, it cannot be mocked + """ + _validate_uuid(statement_id) + + if statement_id not in self.statements: + raise ResourceNotFoundException() + + return StatementResult( + [ + dict(ColumnMetadata(None, False, True, "Number", False)), + dict(ColumnMetadata(None, True, False, "Street", False)), + dict(ColumnMetadata(None, True, False, "City", False)), + ], + [ + [ + dict(Record(long_value=10)), + dict(Record(string_value="Alpha st")), + dict(Record(string_value="Vancouver")), + ], + [ + dict(Record(long_value=50)), + dict(Record(string_value="Beta st")), + dict(Record(string_value="Toronto")), + ], + [ + dict(Record(long_value=100)), + dict(Record(string_value="Gamma av")), + dict(Record(string_value="Seattle")), + ], + ], + 3, + ) + + +def _validate_uuid(uuid): + match = re.search(r"^[a-z0-9]{8}(-[a-z0-9]{4}){3}-[a-z0-9]{12}(:\d+)?$", uuid) + if not match: + raise ValidationException( + "id must satisfy regex pattern: ^[a-z0-9]{8}(-[a-z0-9]{4}){3}-[a-z0-9]{12}(:\\d+)?$" + ) + + +# For unknown reasons I cannot use the service name "redshift-data" as I should +# It seems boto3 is unable to get the list of available regions for "redshift-data" +# See code here https://github.com/spulec/moto/blob/master/moto/core/utils.py#L407 +# sess.get_available_regions("redshift-data") returns an empty list +# Then I use the service redshift since they share the same regions +# See https://docs.aws.amazon.com/general/latest/gr/redshift-service.html +redshiftdata_backends = BackendDict(RedshiftDataAPIServiceBackend, "redshift") diff --git a/moto/redshiftdata/responses.py b/moto/redshiftdata/responses.py new file mode 100644 index 000000000..4269bd1c3 --- /dev/null +++ b/moto/redshiftdata/responses.py @@ -0,0 +1,60 @@ +import json +from moto.core.responses import BaseResponse +from .models import redshiftdata_backends + + +class RedshiftDataAPIServiceResponse(BaseResponse): + @property + def redshiftdata_backend(self): + return redshiftdata_backends[self.region] + + def cancel_statement(self): + statement_id = self._get_param("Id") + status = self.redshiftdata_backend.cancel_statement(statement_id=statement_id,) + return 200, {}, json.dumps({"Status": status}) + + def describe_statement(self): + statement_id = self._get_param("Id") + statement = self.redshiftdata_backend.describe_statement( + statement_id=statement_id, + ) + return 200, {}, json.dumps(dict(statement)) + + def execute_statement(self): + cluster_identifier = self._get_param("ClusterIdentifier") + database = self._get_param("Database") + db_user = self._get_param("DbUser") + parameters = self._get_param("Parameters") + secret_arn = self._get_param("SecretArn") + sql = self._get_param("Sql") + statement = self.redshiftdata_backend.execute_statement( + cluster_identifier=cluster_identifier, + database=database, + db_user=db_user, + parameters=parameters, + secret_arn=secret_arn, + sql=sql, + ) + + return ( + 200, + {}, + json.dumps( + { + "ClusterIdentifier": statement.cluster_identifier, + "CreatedAt": statement.created_at, + "Database": statement.database, + "DbUser": statement.db_user, + "Id": statement.id, + "SecretArn": statement.secret_arn, + } + ), + ) + + def get_statement_result(self): + statement_id = self._get_param("Id") + statement_result = self.redshiftdata_backend.get_statement_result( + statement_id=statement_id, + ) + + return 200, {}, json.dumps(dict(statement_result)) diff --git a/moto/redshiftdata/urls.py b/moto/redshiftdata/urls.py new file mode 100644 index 000000000..f1b8a3de4 --- /dev/null +++ b/moto/redshiftdata/urls.py @@ -0,0 +1,10 @@ +from .responses import RedshiftDataAPIServiceResponse + +url_bases = [ + r"https?://redshift-data\.(.+)\.amazonaws\.com", +] + + +url_paths = { + "{0}/$": RedshiftDataAPIServiceResponse.dispatch, +} diff --git a/tests/helpers.py b/tests/helpers.py index 47d1793db..34341464d 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,4 +1,6 @@ from collections.abc import Iterable, Mapping +from uuid import UUID + from sure import assertion @@ -32,3 +34,12 @@ def match_dict(context, dict_value): assert k in context.obj, f"No such key '{k}' in {context.obj}" context.obj[k].should.equal(v) return True + + +@assertion +def match_uuid4(context): + try: + uuid_obj = UUID(context.obj, version=4) + except ValueError: + return False + return str(uuid_obj) == context.obj diff --git a/tests/test_redshiftdata/__init__.py b/tests/test_redshiftdata/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_redshiftdata/test_redshiftdata.py b/tests/test_redshiftdata/test_redshiftdata.py new file mode 100644 index 000000000..63667cfb8 --- /dev/null +++ b/tests/test_redshiftdata/test_redshiftdata.py @@ -0,0 +1,182 @@ +import boto3 +import pytest +import sure # noqa # pylint: disable=unused-import +from botocore.exceptions import ClientError +from moto import mock_redshiftdata +from tests.test_redshiftdata.test_redshiftdata_constants import ErrorAttributes + +REGION = "us-east-1" + +INVALID_ID_ERROR_MESSAGE = ( + "id must satisfy regex pattern: ^[a-z0-9]{8}(-[a-z0-9]{4}){3}-[a-z0-9]{12}(:\\d+)?$" +) +RESOURCE_NOT_FOUND_ERROR_MESSAGE = "Query does not exist." + + +@pytest.fixture(autouse=True) +def client(): + yield boto3.client("redshift-data", region_name=REGION) + + +@mock_redshiftdata +def test_cancel_statement_throws_exception_when_uuid_invalid(client): + statement_id = "test" + + with pytest.raises(ClientError) as raised_exception: + client.cancel_statement(Id=statement_id) + + assert_expected_exception( + raised_exception, "ValidationException", INVALID_ID_ERROR_MESSAGE + ) + + +@mock_redshiftdata +def test_cancel_statement_throws_exception_when_statement_not_found(client): + statement_id = "890f1253-595b-4608-a0d1-73f933ccd0a0" + + with pytest.raises(ClientError) as raised_exception: + client.cancel_statement(Id=statement_id) + + assert_expected_exception( + raised_exception, "ResourceNotFoundException", RESOURCE_NOT_FOUND_ERROR_MESSAGE + ) + + +@mock_redshiftdata +def test_describe_statement_throws_exception_when_uuid_invalid(client): + statement_id = "test" + + with pytest.raises(ClientError) as raised_exception: + client.describe_statement(Id=statement_id) + + assert_expected_exception( + raised_exception, "ValidationException", INVALID_ID_ERROR_MESSAGE + ) + + +@mock_redshiftdata +def test_describe_statement_throws_exception_when_statement_not_found(client): + statement_id = "890f1253-595b-4608-a0d1-73f933ccd0a0" + + with pytest.raises(ClientError) as raised_exception: + client.describe_statement(Id=statement_id) + + assert_expected_exception( + raised_exception, "ResourceNotFoundException", RESOURCE_NOT_FOUND_ERROR_MESSAGE + ) + + +@mock_redshiftdata +def test_get_statement_result_throws_exception_when_uuid_invalid(client): + statement_id = "test" + + with pytest.raises(ClientError) as raised_exception: + client.get_statement_result(Id=statement_id) + + assert_expected_exception( + raised_exception, "ValidationException", INVALID_ID_ERROR_MESSAGE + ) + + +@mock_redshiftdata +def test_get_statement_result_throws_exception_when_statement_not_found(client): + statement_id = "890f1253-595b-4608-a0d1-73f933ccd0a0" + + with pytest.raises(ClientError) as raised_exception: + client.get_statement_result(Id=statement_id) + + assert_expected_exception( + raised_exception, "ResourceNotFoundException", RESOURCE_NOT_FOUND_ERROR_MESSAGE + ) + + +@mock_redshiftdata +def test_execute_statement_and_cancel_statement(client): + cluster_identifier = "cluster_identifier" + database = "database" + db_user = "db_user" + parameters = [{"name": "name", "value": "value"}] + secret_arn = "secret_arn" + sql = "sql" + + # Execute statement + execute_response = client.execute_statement( + ClusterIdentifier=cluster_identifier, + Database=database, + DbUser=db_user, + Parameters=parameters, + SecretArn=secret_arn, + Sql=sql, + ) + + # Cancel statement + cancel_response = client.cancel_statement(Id=execute_response["Id"]) + + cancel_response["Status"].should.equal(True) + + +@mock_redshiftdata +def test_execute_statement_and_describe_statement(client): + cluster_identifier = "cluster_identifier" + database = "database" + db_user = "db_user" + parameters = [{"name": "name", "value": "value"}] + secret_arn = "secret_arn" + sql = "sql" + + # Execute statement + execute_response = client.execute_statement( + ClusterIdentifier=cluster_identifier, + Database=database, + DbUser=db_user, + Parameters=parameters, + SecretArn=secret_arn, + Sql=sql, + ) + + # Describe statement + describe_response = client.describe_statement(Id=execute_response["Id"]) + + describe_response["ClusterIdentifier"].should.equal(cluster_identifier) + describe_response["Database"].should.equal(database) + describe_response["DbUser"].should.equal(db_user) + describe_response["QueryParameters"].should.equal(parameters) + describe_response["SecretArn"].should.equal(secret_arn) + describe_response["QueryString"].should.equal(sql) + describe_response["Status"].should.equal("STARTED") + + +@mock_redshiftdata +def test_execute_statement_and_get_statement_result(client): + cluster_identifier = "cluster_identifier" + database = "database" + db_user = "db_user" + parameters = [{"name": "name", "value": "value"}] + secret_arn = "secret_arn" + sql = "sql" + + # Execute statement + execute_response = client.execute_statement( + ClusterIdentifier=cluster_identifier, + Database=database, + DbUser=db_user, + Parameters=parameters, + SecretArn=secret_arn, + Sql=sql, + ) + + # Get statement result + result_response = client.get_statement_result(Id=execute_response["Id"]) + + result_response["ColumnMetadata"][0]["name"].should.equal("Number") + result_response["ColumnMetadata"][1]["name"].should.equal("Street") + result_response["ColumnMetadata"][2]["name"].should.equal("City") + result_response["Records"][0][0]["longValue"].should.equal(10) + result_response["Records"][1][1]["stringValue"].should.equal("Beta st") + result_response["Records"][2][2]["stringValue"].should.equal("Seattle") + + +def assert_expected_exception(raised_exception, expected_exception, expected_message): + error = raised_exception.value.response[ErrorAttributes.ERROR] + error[ErrorAttributes.CODE].should.equal(expected_exception) + error[ErrorAttributes.MESSAGE].should.equal(expected_message) diff --git a/tests/test_redshiftdata/test_redshiftdata_constants.py b/tests/test_redshiftdata/test_redshiftdata_constants.py new file mode 100644 index 000000000..f42bfe954 --- /dev/null +++ b/tests/test_redshiftdata/test_redshiftdata_constants.py @@ -0,0 +1,11 @@ +class ErrorAttributes: + CODE = "Code" + ERROR = "Error" + MESSAGE = "Message" + + +DEFAULT_ENCODING = "utf-8" + + +class HttpHeaders: + ErrorType = "x-amzn-ErrorType" diff --git a/tests/test_redshiftdata/test_server.py b/tests/test_redshiftdata/test_server.py new file mode 100644 index 000000000..87317e30e --- /dev/null +++ b/tests/test_redshiftdata/test_server.py @@ -0,0 +1,233 @@ +"""Test different server responses.""" +import json +import pytest +import sure # noqa # pylint: disable=unused-import +import moto.server as server +from tests.test_redshiftdata.test_redshiftdata_constants import ( + DEFAULT_ENCODING, + HttpHeaders, +) + +CLIENT_ENDPOINT = "/" + + +def headers(action): + return { + "X-Amz-Target": "RedshiftData.%s" % action, + "Content-Type": "application/x-amz-json-1.1", + } + + +@pytest.fixture(autouse=True) +def client(): + backend = server.create_backend_app("redshift-data") + yield backend.test_client() + + +def test_redshiftdata_cancel_statement_unknown_statement(client): + statement_id = "890f1253-595b-4608-a0d1-73f933ccd0a0" + response = client.post( + CLIENT_ENDPOINT, + data=json.dumps({"Id": statement_id}), + headers=headers("CancelStatement"), + ) + response.status_code.should.equal(400) + should_return_expected_exception( + response, "ResourceNotFoundException", "Query does not exist." + ) + + +def test_redshiftdata_describe_statement_unknown_statement(client): + statement_id = "890f1253-595b-4608-a0d1-73f933ccd0a0" + response = client.post( + CLIENT_ENDPOINT, + data=json.dumps({"Id": statement_id}), + headers=headers("DescribeStatement"), + ) + response.status_code.should.equal(400) + should_return_expected_exception( + response, "ResourceNotFoundException", "Query does not exist." + ) + + +def test_redshiftdata_get_statement_result_unknown_statement(client): + statement_id = "890f1253-595b-4608-a0d1-73f933ccd0a0" + response = client.post( + CLIENT_ENDPOINT, + data=json.dumps({"Id": statement_id}), + headers=headers("GetStatementResult"), + ) + response.status_code.should.equal(400) + should_return_expected_exception( + response, "ResourceNotFoundException", "Query does not exist." + ) + + +def test_redshiftdata_execute_statement_with_minimal_values(client): + database = "database" + sql = "sql" + response = client.post( + CLIENT_ENDPOINT, + data=json.dumps({"Database": database, "Sql": sql}), + headers=headers("ExecuteStatement"), + ) + response.status_code.should.equal(200) + payload = get_payload(response) + + payload["ClusterIdentifier"].should.equal(None) + payload["Database"].should.equal(database) + payload["DbUser"].should.equal(None) + payload["SecretArn"].should.equal(None) + payload["Id"].should.match_uuid4() + + +def test_redshiftdata_execute_statement_with_all_values(client): + cluster = "cluster" + database = "database" + dbUser = "dbUser" + sql = "sql" + secretArn = "secretArn" + + response = client.post( + CLIENT_ENDPOINT, + data=json.dumps( + { + "ClusterIdentifier": cluster, + "Database": database, + "DbUser": dbUser, + "Sql": sql, + "SecretArn": secretArn, + } + ), + headers=headers("ExecuteStatement"), + ) + response.status_code.should.equal(200) + payload = get_payload(response) + + payload["ClusterIdentifier"].should.equal(cluster) + payload["Database"].should.equal(database) + payload["DbUser"].should.equal(dbUser) + payload["SecretArn"].should.equal(secretArn) + + +def test_redshiftdata_execute_statement_and_describe_statement(client): + cluster = "cluster" + database = "database" + dbUser = "dbUser" + sql = "sql" + secretArn = "secretArn" + + # ExecuteStatement + execute_response = client.post( + CLIENT_ENDPOINT, + data=json.dumps( + { + "ClusterIdentifier": cluster, + "Database": database, + "DbUser": dbUser, + "Sql": sql, + "SecretArn": secretArn, + } + ), + headers=headers("ExecuteStatement"), + ) + execute_response.status_code.should.equal(200) + execute_payload = get_payload(execute_response) + + # DescribeStatement + describe_response = client.post( + CLIENT_ENDPOINT, + data=json.dumps({"Id": execute_payload["Id"]}), + headers=headers("DescribeStatement"), + ) + describe_response.status_code.should.equal(200) + describe_payload = get_payload(execute_response) + + describe_payload["ClusterIdentifier"].should.equal(cluster) + describe_payload["Database"].should.equal(database) + describe_payload["DbUser"].should.equal(dbUser) + describe_payload["SecretArn"].should.equal(secretArn) + + +def test_redshiftdata_execute_statement_and_get_statement_result(client): + database = "database" + sql = "sql" + + # ExecuteStatement + execute_response = client.post( + CLIENT_ENDPOINT, + data=json.dumps({"Database": database, "Sql": sql,}), + headers=headers("ExecuteStatement"), + ) + execute_response.status_code.should.equal(200) + execute_payload = get_payload(execute_response) + + # GetStatementResult + statement_result_response = client.post( + CLIENT_ENDPOINT, + data=json.dumps({"Id": execute_payload["Id"]}), + headers=headers("GetStatementResult"), + ) + statement_result_response.status_code.should.equal(200) + statement_result_payload = get_payload(statement_result_response) + statement_result_payload["TotalNumberRows"].should.equal(3) + + # columns + len(statement_result_payload["ColumnMetadata"]).should.equal(3) + statement_result_payload["ColumnMetadata"][0]["name"].should.equal("Number") + statement_result_payload["ColumnMetadata"][1]["name"].should.equal("Street") + statement_result_payload["ColumnMetadata"][2]["name"].should.equal("City") + + # records + len(statement_result_payload["Records"]).should.equal(3) + statement_result_payload["Records"][0][0]["longValue"].should.equal(10) + statement_result_payload["Records"][1][1]["stringValue"].should.equal("Beta st") + statement_result_payload["Records"][2][2]["stringValue"].should.equal("Seattle") + + +def test_redshiftdata_execute_statement_and_cancel_statement(client): + database = "database" + sql = "sql" + + # ExecuteStatement + execute_response = client.post( + CLIENT_ENDPOINT, + data=json.dumps({"Database": database, "Sql": sql,}), + headers=headers("ExecuteStatement"), + ) + execute_response.status_code.should.equal(200) + execute_payload = get_payload(execute_response) + + # CancelStatement 1 + cancel_response1 = client.post( + CLIENT_ENDPOINT, + data=json.dumps({"Id": execute_payload["Id"]}), + headers=headers("CancelStatement"), + ) + cancel_response1.status_code.should.equal(200) + cancel_payload1 = get_payload(cancel_response1) + cancel_payload1["Status"].should.equal(True) + + # CancelStatement 2 + cancel_response2 = client.post( + CLIENT_ENDPOINT, + data=json.dumps({"Id": execute_payload["Id"]}), + headers=headers("CancelStatement"), + ) + cancel_response2.status_code.should.equal(400) + should_return_expected_exception( + cancel_response2, + "ValidationException", + "Could not cancel a query that is already in %s state with ID: %s" + % ("ABORTED", execute_payload["Id"]), + ) + + +def get_payload(response): + return json.loads(response.data.decode(DEFAULT_ENCODING)) + + +def should_return_expected_exception(response, expected_exception, message): + result_data = get_payload(response) + response.headers.get(HttpHeaders.ErrorType).should.equal(expected_exception) + result_data["message"].should.equal(message)