diff --git a/moto/glue/models.py b/moto/glue/models.py index 738f741c9..26cfbbc0e 100644 --- a/moto/glue/models.py +++ b/moto/glue/models.py @@ -2,7 +2,7 @@ import time from collections import OrderedDict from datetime import datetime import re -from typing import List +from typing import Dict, List from moto.core import BaseBackend, BackendDict, BaseModel from moto.moto_api import state_manager @@ -76,7 +76,7 @@ class GlueBackend(BaseBackend): self.jobs = OrderedDict() self.job_runs = OrderedDict() self.tagger = TaggingService() - self.registries = OrderedDict() + self.registries: Dict[str, FakeRegistry] = OrderedDict() self.num_schemas = 0 self.num_schema_versions = 0 @@ -323,17 +323,28 @@ class GlueBackend(BaseBackend): def create_registry(self, registry_name, description=None, tags=None): # If registry name id default-registry, create default-registry if registry_name == DEFAULT_REGISTRY_NAME: - registry = FakeRegistry(self.account_id, registry_name, description, tags) + registry = FakeRegistry(self, registry_name, description, tags) self.registries[registry_name] = registry return registry # Validate Registry Parameters validate_registry_params(self.registries, registry_name, description, tags) - registry = FakeRegistry(self.account_id, registry_name, description, tags) + registry = FakeRegistry(self, registry_name, description, tags) self.registries[registry_name] = registry return registry.as_dict() + def delete_registry(self, registry_id): + registry_name = validate_registry_id(registry_id, self.registries) + return self.registries.pop(registry_name).as_dict() + + def get_registry(self, registry_id): + registry_name = validate_registry_id(registry_id, self.registries) + return self.registries[registry_name].as_dict() + + def list_registries(self): + return [reg.as_dict() for reg in self.registries.values()] + def create_schema( self, registry_id, @@ -371,7 +382,7 @@ class GlueBackend(BaseBackend): # Create Schema schema_version = FakeSchemaVersion( - self.account_id, + self, registry_name, schema_name, schema_definition, @@ -379,14 +390,16 @@ class GlueBackend(BaseBackend): ) schema_version_id = schema_version.get_schema_version_id() schema = FakeSchema( - self.account_id, + self, registry_name, schema_name, data_format, compatibility, schema_version_id, description, - tags, + ) + self.tagger.tag_resource( + schema.schema_arn, self.tagger.convert_dict_to_tags_input(tags) ) registry.schemas[schema_name] = schema self.num_schemas += 1 @@ -394,7 +407,10 @@ class GlueBackend(BaseBackend): schema.schema_versions[schema.schema_version_id] = schema_version self.num_schema_versions += 1 - return schema.as_dict() + resp = schema.as_dict() + if tags: + resp.update({"Tags": tags}) + return resp def register_schema_version(self, schema_id, schema_definition): # Validate Schema Id @@ -442,7 +458,7 @@ class GlueBackend(BaseBackend): self.num_schema_versions += 1 schema_version = FakeSchemaVersion( - self.account_id, + self, registry_name, schema_name, schema_definition, @@ -594,6 +610,11 @@ class GlueBackend(BaseBackend): registry_name, schema_name, schema_arn, version_number, latest_version ) + def get_schema(self, schema_id): + registry_name, schema_name, _ = validate_schema_id(schema_id, self.registries) + schema = self.registries[registry_name].schemas[schema_name] + return schema.as_dict() + def delete_schema(self, schema_id): # Validate schema_id registry_name, schema_name, _ = validate_schema_id(schema_id, self.registries) @@ -613,6 +634,20 @@ class GlueBackend(BaseBackend): return response + def update_schema(self, schema_id, compatibility, description): + """ + The SchemaVersionNumber-argument is not yet implemented + """ + registry_name, schema_name, _ = validate_schema_id(schema_id, self.registries) + schema = self.registries[registry_name].schemas[schema_name] + + if compatibility is not None: + schema.compatibility = compatibility + if description is not None: + schema.description = description + + return schema.as_dict() + def batch_delete_table(self, database_name, tables): errors = [] for table_name in tables: @@ -856,7 +891,7 @@ class FakeCrawler(BaseModel): self.version = 1 self.crawl_elapsed_time = 0 self.last_crawl_info = None - self.arn = f"arn:aws:glue:us-east-1:{backend.account_id}:crawler/{self.name}" + self.arn = f"arn:aws:glue:{backend.region_name}:{backend.account_id}:crawler/{self.name}" self.backend = backend self.backend.tag_resource(self.arn, tags) @@ -978,7 +1013,9 @@ class FakeJob: self.worker_type = worker_type self.created_on = datetime.utcnow() self.last_modified_on = datetime.utcnow() - self.arn = f"arn:aws:glue:us-east-1:{backend.account_id}:job/{self.name}" + self.arn = ( + f"arn:aws:glue:{backend.region_name}:{backend.account_id}:job/{self.name}" + ) self.backend = backend self.backend.tag_resource(self.arn, tags) @@ -1089,15 +1126,15 @@ class FakeJobRun(ManagedState): class FakeRegistry(BaseModel): - def __init__(self, account_id, registry_name, description=None, tags=None): + def __init__(self, backend, registry_name, description=None, tags=None): self.name = registry_name self.description = description self.tags = tags self.created_time = datetime.utcnow() self.updated_time = datetime.utcnow() self.status = "AVAILABLE" - self.registry_arn = f"arn:aws:glue:us-east-1:{account_id}:registry/{self.name}" - self.schemas = OrderedDict() + self.registry_arn = f"arn:aws:glue:{backend.region_name}:{backend.account_id}:registry/{self.name}" + self.schemas: Dict[str, FakeSchema] = OrderedDict() def as_dict(self): return { @@ -1111,21 +1148,18 @@ class FakeRegistry(BaseModel): class FakeSchema(BaseModel): def __init__( self, - account_id, + backend: GlueBackend, registry_name, schema_name, data_format, compatibility, schema_version_id, description=None, - tags=None, ): self.registry_name = registry_name - self.registry_arn = ( - f"arn:aws:glue:us-east-1:{account_id}:registry/{self.registry_name}" - ) + self.registry_arn = f"arn:aws:glue:{backend.region_name}:{backend.account_id}:registry/{self.registry_name}" self.schema_name = schema_name - self.schema_arn = f"arn:aws:glue:us-east-1:{account_id}:schema/{self.registry_name}/{self.schema_name}" + self.schema_arn = f"arn:aws:glue:{backend.region_name}:{backend.account_id}:schema/{self.registry_name}/{self.schema_name}" self.description = description self.data_format = data_format self.compatibility = compatibility @@ -1133,7 +1167,6 @@ class FakeSchema(BaseModel): self.latest_schema_version = 1 self.next_schema_version = 2 self.schema_status = AVAILABLE_STATUS - self.tags = tags self.schema_version_id = schema_version_id self.schema_version_status = AVAILABLE_STATUS self.created_time = datetime.utcnow() @@ -1164,17 +1197,21 @@ class FakeSchema(BaseModel): "SchemaVersionId": self.schema_version_id, "SchemaVersionStatus": self.schema_version_status, "Description": self.description, - "Tags": self.tags, } class FakeSchemaVersion(BaseModel): def __init__( - self, account_id, registry_name, schema_name, schema_definition, version_number + self, + backend: GlueBackend, + registry_name, + schema_name, + schema_definition, + version_number, ): self.registry_name = registry_name self.schema_name = schema_name - self.schema_arn = f"arn:aws:glue:us-east-1:{account_id}:schema/{self.registry_name}/{self.schema_name}" + self.schema_arn = f"arn:aws:glue:{backend.region_name}:{backend.account_id}:schema/{self.registry_name}/{self.schema_name}" self.schema_definition = schema_definition self.schema_version_status = AVAILABLE_STATUS self.version_number = version_number @@ -1214,6 +1251,4 @@ class FakeSchemaVersion(BaseModel): } -glue_backends = BackendDict( - GlueBackend, "glue", use_boto3_regions=False, additional_regions=["global"] -) +glue_backends = BackendDict(GlueBackend, "glue") diff --git a/moto/glue/responses.py b/moto/glue/responses.py index e4eea8031..486e0bee7 100644 --- a/moto/glue/responses.py +++ b/moto/glue/responses.py @@ -1,7 +1,7 @@ import json from moto.core.responses import BaseResponse -from .models import glue_backends +from .models import glue_backends, GlueBackend class GlueResponse(BaseResponse): @@ -9,8 +9,8 @@ class GlueResponse(BaseResponse): super().__init__(service_name="glue") @property - def glue_backend(self): - return glue_backends[self.current_account]["global"] + def glue_backend(self) -> GlueBackend: + return glue_backends[self.current_account][self.region] @property def parameters(self): @@ -413,6 +413,20 @@ class GlueResponse(BaseResponse): registry = self.glue_backend.create_registry(registry_name, description, tags) return json.dumps(registry) + def delete_registry(self): + registry_id = self._get_param("RegistryId") + registry = self.glue_backend.delete_registry(registry_id) + return json.dumps(registry) + + def get_registry(self): + registry_id = self._get_param("RegistryId") + registry = self.glue_backend.get_registry(registry_id) + return json.dumps(registry) + + def list_registries(self): + registries = self.glue_backend.list_registries() + return json.dumps({"Registries": registries}) + def create_schema(self): registry_id = self._get_param("RegistryId") schema_name = self._get_param("SchemaName") @@ -468,7 +482,19 @@ class GlueResponse(BaseResponse): ) return json.dumps(schema_version) + def get_schema(self): + schema_id = self._get_param("SchemaId") + schema = self.glue_backend.get_schema(schema_id) + return json.dumps(schema) + def delete_schema(self): schema_id = self._get_param("SchemaId") schema = self.glue_backend.delete_schema(schema_id) return json.dumps(schema) + + def update_schema(self): + schema_id = self._get_param("SchemaId") + compatibility = self._get_param("Compatibility") + description = self._get_param("Description") + schema = self.glue_backend.update_schema(schema_id, compatibility, description) + return json.dumps(schema) diff --git a/tests/terraformtests/terraform-tests.success.txt b/tests/terraformtests/terraform-tests.success.txt index edba3828f..01e53f602 100644 --- a/tests/terraformtests/terraform-tests.success.txt +++ b/tests/terraformtests/terraform-tests.success.txt @@ -200,6 +200,8 @@ events: - TestAccEventsConnection - TestAccEventsConnectionDataSource - TestAccEventsPermission +glue: + - TestAccGlueSchema_ guardduty: - TestAccGuardDuty_serial/Detector/basic - TestAccGuardDuty_serial/Filter/basic diff --git a/tests/test_glue/test_schema_registry.py b/tests/test_glue/test_schema_registry.py index edc0628ad..fb6040117 100644 --- a/tests/test_glue/test_schema_registry.py +++ b/tests/test_glue/test_schema_registry.py @@ -1204,6 +1204,36 @@ def test_put_schema_version_metadata_invalid_characters_metadata_value_schema_ve ) +def test_get_schema(client): + helpers.create_registry(client) + helpers.create_schema(client, TEST_REGISTRY_ID) + + response = client.get_schema( + SchemaId={"RegistryName": TEST_REGISTRY_NAME, "SchemaName": TEST_SCHEMA_NAME} + ) + response.should.have.key("SchemaArn").equals(TEST_SCHEMA_ARN) + response.should.have.key("SchemaName").equals(TEST_SCHEMA_NAME) + + +def test_update_schema(client): + helpers.create_registry(client) + helpers.create_schema(client, TEST_REGISTRY_ID) + + client.update_schema( + SchemaId={"RegistryName": TEST_REGISTRY_NAME, "SchemaName": TEST_SCHEMA_NAME}, + Compatibility="FORWARD", + Description="updated schema", + ) + + response = client.get_schema( + SchemaId={"RegistryName": TEST_REGISTRY_NAME, "SchemaName": TEST_SCHEMA_NAME} + ) + response.should.have.key("SchemaArn").equals(TEST_SCHEMA_ARN) + response.should.have.key("SchemaName").equals(TEST_SCHEMA_NAME) + response.should.have.key("Description").equals("updated schema") + response.should.have.key("Compatibility").equals("FORWARD") + + # test delete_schema def test_delete_schema_valid_input(client): helpers.create_registry(client) @@ -1257,3 +1287,28 @@ def test_delete_schema_schema_not_found(client): err["Message"].should.have( f"Schema is not found. RegistryName: {TEST_REGISTRY_NAME}, SchemaName: {TEST_SCHEMA_NAME}, SchemaArn: null" ) + + +def test_list_registries(client): + helpers.create_registry(client) + helpers.create_registry(client, registry_name="registry2") + + registries = client.list_registries()["Registries"] + registries.should.have.length_of(2) + + +@pytest.mark.parametrize("name_or_arn", ["RegistryArn", "RegistryName"]) +def test_get_registry(client, name_or_arn): + x = helpers.create_registry(client) + + r = client.get_registry(RegistryId={name_or_arn: x[name_or_arn]}) + r.should.have.key("RegistryName").equals(x["RegistryName"]) + r.should.have.key("RegistryArn").equals(x["RegistryArn"]) + + +@pytest.mark.parametrize("name_or_arn", ["RegistryArn", "RegistryName"]) +def test_delete_registry(client, name_or_arn): + x = helpers.create_registry(client) + + client.delete_registry(RegistryId={name_or_arn: x[name_or_arn]}) + client.list_registries()["Registries"].should.have.length_of(0)