diff --git a/moto/athena/models.py b/moto/athena/models.py index 78eee9313..4fec79325 100644 --- a/moto/athena/models.py +++ b/moto/athena/models.py @@ -42,7 +42,7 @@ class WorkGroup(TaggableResourceMixin, BaseModel): self, athena_backend: "AthenaBackend", name: str, - configuration: str, + configuration: Dict[str, Any], description: str, tags: List[Dict[str, str]], ): @@ -161,7 +161,9 @@ class AthenaBackend(BaseBackend): self.prepared_statements: Dict[str, PreparedStatement] = {} # Initialise with the primary workgroup - self.create_work_group("primary", "", "", []) + self.create_work_group( + name="primary", description="", configuration=dict(), tags=[] + ) @staticmethod def default_vpc_endpoint_service( @@ -175,7 +177,7 @@ class AthenaBackend(BaseBackend): def create_work_group( self, name: str, - configuration: str, + configuration: Dict[str, Any], description: str, tags: List[Dict[str, str]], ) -> Optional[WorkGroup]: diff --git a/tests/test_athena/test_athena.py b/tests/test_athena/test_athena.py index 1389511ef..12bc320bd 100644 --- a/tests/test_athena/test_athena.py +++ b/tests/test_athena/test_athena.py @@ -12,7 +12,7 @@ from moto.core import DEFAULT_ACCOUNT_ID def test_create_work_group(): client = boto3.client("athena", region_name="us-east-1") - response = client.create_work_group( + client.create_work_group( Name="athena_workgroup", Description="Test work group", Configuration={ @@ -24,12 +24,11 @@ def test_create_work_group(): }, } }, - Tags=[], ) try: # The second time should throw an error - response = client.create_work_group( + client.create_work_group( Name="athena_workgroup", Description="duplicate", Configuration={ @@ -61,6 +60,16 @@ def test_create_work_group(): work_group["State"].should.equal("ENABLED") +@mock_athena +def test_get_primary_workgroup(): + client = boto3.client("athena", region_name="us-east-1") + assert len(client.list_work_groups()["WorkGroups"]) == 1 + + primary = client.get_work_group(WorkGroup="primary")["WorkGroup"] + assert primary["Name"] == "primary" + assert primary["Configuration"] == {} + + @mock_athena def test_create_and_get_workgroup(): client = boto3.client("athena", region_name="us-east-1")