234 lines
		
	
	
		
			7.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			234 lines
		
	
	
		
			7.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| 
								 | 
							
								"""Test different server responses."""
							 | 
						||
| 
								 | 
							
								import json
							 | 
						||
| 
								 | 
							
								import pytest
							 | 
						||
| 
								 | 
							
								import sure  # noqa # pylint: disable=unused-import
							 | 
						||
| 
								 | 
							
								import moto.server as server
							 | 
						||
| 
								 | 
							
								from tests.test_redshiftdata.test_redshiftdata_constants import (
							 | 
						||
| 
								 | 
							
								    DEFAULT_ENCODING,
							 | 
						||
| 
								 | 
							
								    HttpHeaders,
							 | 
						||
| 
								 | 
							
								)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								CLIENT_ENDPOINT = "/"
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def headers(action):
							 | 
						||
| 
								 | 
							
								    return {
							 | 
						||
| 
								 | 
							
								        "X-Amz-Target": "RedshiftData.%s" % action,
							 | 
						||
| 
								 | 
							
								        "Content-Type": "application/x-amz-json-1.1",
							 | 
						||
| 
								 | 
							
								    }
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								@pytest.fixture(autouse=True)
							 | 
						||
| 
								 | 
							
								def client():
							 | 
						||
| 
								 | 
							
								    backend = server.create_backend_app("redshift-data")
							 | 
						||
| 
								 | 
							
								    yield backend.test_client()
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def test_redshiftdata_cancel_statement_unknown_statement(client):
							 | 
						||
| 
								 | 
							
								    statement_id = "890f1253-595b-4608-a0d1-73f933ccd0a0"
							 | 
						||
| 
								 | 
							
								    response = client.post(
							 | 
						||
| 
								 | 
							
								        CLIENT_ENDPOINT,
							 | 
						||
| 
								 | 
							
								        data=json.dumps({"Id": statement_id}),
							 | 
						||
| 
								 | 
							
								        headers=headers("CancelStatement"),
							 | 
						||
| 
								 | 
							
								    )
							 | 
						||
| 
								 | 
							
								    response.status_code.should.equal(400)
							 | 
						||
| 
								 | 
							
								    should_return_expected_exception(
							 | 
						||
| 
								 | 
							
								        response, "ResourceNotFoundException", "Query does not exist."
							 | 
						||
| 
								 | 
							
								    )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def test_redshiftdata_describe_statement_unknown_statement(client):
							 | 
						||
| 
								 | 
							
								    statement_id = "890f1253-595b-4608-a0d1-73f933ccd0a0"
							 | 
						||
| 
								 | 
							
								    response = client.post(
							 | 
						||
| 
								 | 
							
								        CLIENT_ENDPOINT,
							 | 
						||
| 
								 | 
							
								        data=json.dumps({"Id": statement_id}),
							 | 
						||
| 
								 | 
							
								        headers=headers("DescribeStatement"),
							 | 
						||
| 
								 | 
							
								    )
							 | 
						||
| 
								 | 
							
								    response.status_code.should.equal(400)
							 | 
						||
| 
								 | 
							
								    should_return_expected_exception(
							 | 
						||
| 
								 | 
							
								        response, "ResourceNotFoundException", "Query does not exist."
							 | 
						||
| 
								 | 
							
								    )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def test_redshiftdata_get_statement_result_unknown_statement(client):
							 | 
						||
| 
								 | 
							
								    statement_id = "890f1253-595b-4608-a0d1-73f933ccd0a0"
							 | 
						||
| 
								 | 
							
								    response = client.post(
							 | 
						||
| 
								 | 
							
								        CLIENT_ENDPOINT,
							 | 
						||
| 
								 | 
							
								        data=json.dumps({"Id": statement_id}),
							 | 
						||
| 
								 | 
							
								        headers=headers("GetStatementResult"),
							 | 
						||
| 
								 | 
							
								    )
							 | 
						||
| 
								 | 
							
								    response.status_code.should.equal(400)
							 | 
						||
| 
								 | 
							
								    should_return_expected_exception(
							 | 
						||
| 
								 | 
							
								        response, "ResourceNotFoundException", "Query does not exist."
							 | 
						||
| 
								 | 
							
								    )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def test_redshiftdata_execute_statement_with_minimal_values(client):
							 | 
						||
| 
								 | 
							
								    database = "database"
							 | 
						||
| 
								 | 
							
								    sql = "sql"
							 | 
						||
| 
								 | 
							
								    response = client.post(
							 | 
						||
| 
								 | 
							
								        CLIENT_ENDPOINT,
							 | 
						||
| 
								 | 
							
								        data=json.dumps({"Database": database, "Sql": sql}),
							 | 
						||
| 
								 | 
							
								        headers=headers("ExecuteStatement"),
							 | 
						||
| 
								 | 
							
								    )
							 | 
						||
| 
								 | 
							
								    response.status_code.should.equal(200)
							 | 
						||
| 
								 | 
							
								    payload = get_payload(response)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    payload["ClusterIdentifier"].should.equal(None)
							 | 
						||
| 
								 | 
							
								    payload["Database"].should.equal(database)
							 | 
						||
| 
								 | 
							
								    payload["DbUser"].should.equal(None)
							 | 
						||
| 
								 | 
							
								    payload["SecretArn"].should.equal(None)
							 | 
						||
| 
								 | 
							
								    payload["Id"].should.match_uuid4()
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def test_redshiftdata_execute_statement_with_all_values(client):
							 | 
						||
| 
								 | 
							
								    cluster = "cluster"
							 | 
						||
| 
								 | 
							
								    database = "database"
							 | 
						||
| 
								 | 
							
								    dbUser = "dbUser"
							 | 
						||
| 
								 | 
							
								    sql = "sql"
							 | 
						||
| 
								 | 
							
								    secretArn = "secretArn"
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    response = client.post(
							 | 
						||
| 
								 | 
							
								        CLIENT_ENDPOINT,
							 | 
						||
| 
								 | 
							
								        data=json.dumps(
							 | 
						||
| 
								 | 
							
								            {
							 | 
						||
| 
								 | 
							
								                "ClusterIdentifier": cluster,
							 | 
						||
| 
								 | 
							
								                "Database": database,
							 | 
						||
| 
								 | 
							
								                "DbUser": dbUser,
							 | 
						||
| 
								 | 
							
								                "Sql": sql,
							 | 
						||
| 
								 | 
							
								                "SecretArn": secretArn,
							 | 
						||
| 
								 | 
							
								            }
							 | 
						||
| 
								 | 
							
								        ),
							 | 
						||
| 
								 | 
							
								        headers=headers("ExecuteStatement"),
							 | 
						||
| 
								 | 
							
								    )
							 | 
						||
| 
								 | 
							
								    response.status_code.should.equal(200)
							 | 
						||
| 
								 | 
							
								    payload = get_payload(response)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    payload["ClusterIdentifier"].should.equal(cluster)
							 | 
						||
| 
								 | 
							
								    payload["Database"].should.equal(database)
							 | 
						||
| 
								 | 
							
								    payload["DbUser"].should.equal(dbUser)
							 | 
						||
| 
								 | 
							
								    payload["SecretArn"].should.equal(secretArn)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def test_redshiftdata_execute_statement_and_describe_statement(client):
							 | 
						||
| 
								 | 
							
								    cluster = "cluster"
							 | 
						||
| 
								 | 
							
								    database = "database"
							 | 
						||
| 
								 | 
							
								    dbUser = "dbUser"
							 | 
						||
| 
								 | 
							
								    sql = "sql"
							 | 
						||
| 
								 | 
							
								    secretArn = "secretArn"
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    # ExecuteStatement
							 | 
						||
| 
								 | 
							
								    execute_response = client.post(
							 | 
						||
| 
								 | 
							
								        CLIENT_ENDPOINT,
							 | 
						||
| 
								 | 
							
								        data=json.dumps(
							 | 
						||
| 
								 | 
							
								            {
							 | 
						||
| 
								 | 
							
								                "ClusterIdentifier": cluster,
							 | 
						||
| 
								 | 
							
								                "Database": database,
							 | 
						||
| 
								 | 
							
								                "DbUser": dbUser,
							 | 
						||
| 
								 | 
							
								                "Sql": sql,
							 | 
						||
| 
								 | 
							
								                "SecretArn": secretArn,
							 | 
						||
| 
								 | 
							
								            }
							 | 
						||
| 
								 | 
							
								        ),
							 | 
						||
| 
								 | 
							
								        headers=headers("ExecuteStatement"),
							 | 
						||
| 
								 | 
							
								    )
							 | 
						||
| 
								 | 
							
								    execute_response.status_code.should.equal(200)
							 | 
						||
| 
								 | 
							
								    execute_payload = get_payload(execute_response)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    # DescribeStatement
							 | 
						||
| 
								 | 
							
								    describe_response = client.post(
							 | 
						||
| 
								 | 
							
								        CLIENT_ENDPOINT,
							 | 
						||
| 
								 | 
							
								        data=json.dumps({"Id": execute_payload["Id"]}),
							 | 
						||
| 
								 | 
							
								        headers=headers("DescribeStatement"),
							 | 
						||
| 
								 | 
							
								    )
							 | 
						||
| 
								 | 
							
								    describe_response.status_code.should.equal(200)
							 | 
						||
| 
								 | 
							
								    describe_payload = get_payload(execute_response)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    describe_payload["ClusterIdentifier"].should.equal(cluster)
							 | 
						||
| 
								 | 
							
								    describe_payload["Database"].should.equal(database)
							 | 
						||
| 
								 | 
							
								    describe_payload["DbUser"].should.equal(dbUser)
							 | 
						||
| 
								 | 
							
								    describe_payload["SecretArn"].should.equal(secretArn)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def test_redshiftdata_execute_statement_and_get_statement_result(client):
							 | 
						||
| 
								 | 
							
								    database = "database"
							 | 
						||
| 
								 | 
							
								    sql = "sql"
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    # ExecuteStatement
							 | 
						||
| 
								 | 
							
								    execute_response = client.post(
							 | 
						||
| 
								 | 
							
								        CLIENT_ENDPOINT,
							 | 
						||
| 
								 | 
							
								        data=json.dumps({"Database": database, "Sql": sql,}),
							 | 
						||
| 
								 | 
							
								        headers=headers("ExecuteStatement"),
							 | 
						||
| 
								 | 
							
								    )
							 | 
						||
| 
								 | 
							
								    execute_response.status_code.should.equal(200)
							 | 
						||
| 
								 | 
							
								    execute_payload = get_payload(execute_response)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    # GetStatementResult
							 | 
						||
| 
								 | 
							
								    statement_result_response = client.post(
							 | 
						||
| 
								 | 
							
								        CLIENT_ENDPOINT,
							 | 
						||
| 
								 | 
							
								        data=json.dumps({"Id": execute_payload["Id"]}),
							 | 
						||
| 
								 | 
							
								        headers=headers("GetStatementResult"),
							 | 
						||
| 
								 | 
							
								    )
							 | 
						||
| 
								 | 
							
								    statement_result_response.status_code.should.equal(200)
							 | 
						||
| 
								 | 
							
								    statement_result_payload = get_payload(statement_result_response)
							 | 
						||
| 
								 | 
							
								    statement_result_payload["TotalNumberRows"].should.equal(3)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    # columns
							 | 
						||
| 
								 | 
							
								    len(statement_result_payload["ColumnMetadata"]).should.equal(3)
							 | 
						||
| 
								 | 
							
								    statement_result_payload["ColumnMetadata"][0]["name"].should.equal("Number")
							 | 
						||
| 
								 | 
							
								    statement_result_payload["ColumnMetadata"][1]["name"].should.equal("Street")
							 | 
						||
| 
								 | 
							
								    statement_result_payload["ColumnMetadata"][2]["name"].should.equal("City")
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    # records
							 | 
						||
| 
								 | 
							
								    len(statement_result_payload["Records"]).should.equal(3)
							 | 
						||
| 
								 | 
							
								    statement_result_payload["Records"][0][0]["longValue"].should.equal(10)
							 | 
						||
| 
								 | 
							
								    statement_result_payload["Records"][1][1]["stringValue"].should.equal("Beta st")
							 | 
						||
| 
								 | 
							
								    statement_result_payload["Records"][2][2]["stringValue"].should.equal("Seattle")
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def test_redshiftdata_execute_statement_and_cancel_statement(client):
							 | 
						||
| 
								 | 
							
								    database = "database"
							 | 
						||
| 
								 | 
							
								    sql = "sql"
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    # ExecuteStatement
							 | 
						||
| 
								 | 
							
								    execute_response = client.post(
							 | 
						||
| 
								 | 
							
								        CLIENT_ENDPOINT,
							 | 
						||
| 
								 | 
							
								        data=json.dumps({"Database": database, "Sql": sql,}),
							 | 
						||
| 
								 | 
							
								        headers=headers("ExecuteStatement"),
							 | 
						||
| 
								 | 
							
								    )
							 | 
						||
| 
								 | 
							
								    execute_response.status_code.should.equal(200)
							 | 
						||
| 
								 | 
							
								    execute_payload = get_payload(execute_response)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    # CancelStatement 1
							 | 
						||
| 
								 | 
							
								    cancel_response1 = client.post(
							 | 
						||
| 
								 | 
							
								        CLIENT_ENDPOINT,
							 | 
						||
| 
								 | 
							
								        data=json.dumps({"Id": execute_payload["Id"]}),
							 | 
						||
| 
								 | 
							
								        headers=headers("CancelStatement"),
							 | 
						||
| 
								 | 
							
								    )
							 | 
						||
| 
								 | 
							
								    cancel_response1.status_code.should.equal(200)
							 | 
						||
| 
								 | 
							
								    cancel_payload1 = get_payload(cancel_response1)
							 | 
						||
| 
								 | 
							
								    cancel_payload1["Status"].should.equal(True)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    # CancelStatement 2
							 | 
						||
| 
								 | 
							
								    cancel_response2 = client.post(
							 | 
						||
| 
								 | 
							
								        CLIENT_ENDPOINT,
							 | 
						||
| 
								 | 
							
								        data=json.dumps({"Id": execute_payload["Id"]}),
							 | 
						||
| 
								 | 
							
								        headers=headers("CancelStatement"),
							 | 
						||
| 
								 | 
							
								    )
							 | 
						||
| 
								 | 
							
								    cancel_response2.status_code.should.equal(400)
							 | 
						||
| 
								 | 
							
								    should_return_expected_exception(
							 | 
						||
| 
								 | 
							
								        cancel_response2,
							 | 
						||
| 
								 | 
							
								        "ValidationException",
							 | 
						||
| 
								 | 
							
								        "Could not cancel a query that is already in %s state with ID: %s"
							 | 
						||
| 
								 | 
							
								        % ("ABORTED", execute_payload["Id"]),
							 | 
						||
| 
								 | 
							
								    )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def get_payload(response):
							 | 
						||
| 
								 | 
							
								    return json.loads(response.data.decode(DEFAULT_ENCODING))
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def should_return_expected_exception(response, expected_exception, message):
							 | 
						||
| 
								 | 
							
								    result_data = get_payload(response)
							 | 
						||
| 
								 | 
							
								    response.headers.get(HttpHeaders.ErrorType).should.equal(expected_exception)
							 | 
						||
| 
								 | 
							
								    result_data["message"].should.equal(message)
							 |